import torch
import torch.nn as nn
from torchvision import models

import torch.nn.functional as F

from module import *
import module
from global_var import GlobalVariables


# The helpers below, which calculate the input features of the FC layer
# and initialize the CNN weights, are based on the following GitHub repo:
# https://github.com/Lab41/cyphercat/blob/master/Utils/models.py
 
def size_conv(size, kernel, stride=1, padding=0):
    """Return the output side length of a convolution along one dimension.

    Evaluates ``(size - kernel + 2*padding) / stride + 1`` with true division
    and truncates the result with ``int``.
    """
    span = size - kernel + 2 * padding
    return int(span / stride + 1)
    
    
def size_max_pool(size, kernel, stride=None, padding=0):
    """Return the output side length of a max-pool layer along one dimension.

    Args:
        size: input side length.
        kernel: pooling window size.
        stride: pooling stride; ``None`` means "same as ``kernel``", matching
            the ``nn.MaxPool2d`` default.
        padding: implicit padding on each side.

    Returns:
        ``int((size - kernel + 2*padding) / stride + 1)``.
    """
    # Compare against the None singleton by identity, not equality.
    if stride is None:
        stride = kernel
    return int(((size - kernel + 2 * padding) / stride) + 1)

# Calculate the spatial side length that feeds the FC layer of TargetNet/ShadowNet (CIFAR)
def calc_feat_linear_cifar(size):
    """Return the spatial side length after the two CIFAR conv stages.

    Mirrors the feature extractor of ``TargetNet``/``ShadowNet``: two rounds
    of (3x3 conv, stride 1, padding 1) followed by (2x2 max-pool, stride 2).
    """
    side = size
    for _ in range(2):
        side = size_conv(side, 3, 1, 1)
        side = size_max_pool(side, 2, 2)
    return side
    
# Calculate the spatial side length that feeds the FC layer of MNISTNet
def calc_feat_linear_mnist(size):
    """Return the spatial side length after the two MNIST conv stages.

    Mirrors the feature extractor of ``MNISTNet``: two rounds of
    (5x5 conv, stride 1, no padding) followed by (2x2 max-pool, stride 2).
    """
    side = size
    for _ in range(2):
        side = size_conv(side, 5, 1)
        side = size_max_pool(side, 2, 2)
    return side

#Parameter Initialization
def init_params(m):
    """Initialize the parameters of a single module in place.

    Intended for use with ``model.apply(init_params)``:

    * ``nn.Conv2d``: Kaiming-normal weights (``fan_out``, ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: Xavier-normal weights, zero bias.

    Other module types are left untouched. Parameters that do not exist
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False batch norms store weight/bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # nn.init functions run under no_grad, so .data is not needed.
        nn.init.xavier_normal_(m.weight)
        # nn.Linear(bias=False) stores bias=None; zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

#####################################################
# Define Target, Shadow and Attack Model Architecture
#####################################################


#Target Model
class TargetNet(nn.Module):
    """Two-stage convolutional classifier used as the target model.

    Each stage is 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool,
    followed by a two-layer fully connected head.

    Args:
        input_dim: number of input image channels.
        hidden_layers: indexable of at least three ints: channels of the first
            and second conv stage, then the hidden width of the FC head.
        size: side length of the (square) input image.
        out_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_layers, size, out_classes):
        super().__init__()
        first_ch = hidden_layers[0]
        second_ch = hidden_layers[1]
        fc_width = hidden_layers[2]

        def conv_stage(c_in, c_out):
            # The index of each module inside the Sequential defines the
            # state_dict keys (conv1.0.weight, ...), so keep this order.
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        self.conv1 = conv_stage(input_dim, first_ch)
        self.conv2 = conv_stage(first_ch, second_ch)

        side = calc_feat_linear_cifar(size)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(side * side * second_ch, fc_width),
            nn.ReLU(inplace=True),
            nn.Linear(fc_width, out_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.conv2(self.conv1(x))
        return self.classifier(features)
        
    
# Shadow model mimicking the target model's architecture (its layers are currently identical to TargetNet's)
class ShadowNet(nn.Module):
    """Shadow classifier with the same two-stage CNN layout as ``TargetNet``.

    Args:
        input_dim: number of input image channels.
        hidden_layers: indexable of at least three ints: conv channels of
            stage 1 and stage 2, then the hidden width of the FC head.
        size: side length of the (square) input image.
        out_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_layers, size, out_classes):
        super().__init__()
        channel_pairs = zip(
            (input_dim, hidden_layers[0]),
            (hidden_layers[0], hidden_layers[1]),
        )
        # Built in order, so conv1 is fully constructed before conv2.
        stages = [
            nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            for c_in, c_out in channel_pairs
        ]
        self.conv1, self.conv2 = stages

        flat_dim = calc_feat_linear_cifar(size) ** 2 * hidden_layers[1]
        head = [
            nn.Flatten(),
            nn.Linear(flat_dim, hidden_layers[2]),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_layers[2], out_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = x
        for stage in (self.conv1, self.conv2, self.classifier):
            out = stage(out)
        return out

#Pretrained VGG11 model for Target
class VggModel(nn.Module):
    """VGG11-BN convolutional backbone with a new BN-MLP classification head.

    The torchvision ``avgpool`` and ``classifier`` are discarded. The feature
    map is flattened and fed to ``Linear -> BatchNorm1d -> ReLU -> Linear``.

    Args:
        num_classes: number of output logits.
        layer_config: ``(in_features, hidden)`` of the new head.
            ``in_features`` must equal the flattened feature-map size
            (presumably 512 for 32x32 CIFAR inputs; confirm).
        pretrained: if True, load torchvision's pretrained weights into the
            backbone; otherwise the backbone is randomly initialized.
    """

    def __init__(self, num_classes, layer_config, pretrained=True):
        super(VggModel, self).__init__()
        # Build the backbone unconditionally. Previously nothing was built
        # when pretrained=False, so forward() failed with AttributeError.
        pt_vgg = models.vgg11_bn(pretrained=pretrained)

        # Drop the original pooling/FC head; only the conv features are reused.
        print('### Deleting Avg pooling and FC Layers ####')
        del pt_vgg.avgpool
        del pt_vgg.classifier

        self.model_features = nn.Sequential(*list(pt_vgg.features.children()))

        # New FC head with BN and ReLU for CIFAR-10 classification
        self.model_classifier = nn.Sequential(
            nn.Linear(layer_config[0], layer_config[1]),
            nn.BatchNorm1d(layer_config[1]),
            nn.ReLU(inplace=True),
            nn.Linear(layer_config[1], num_classes),
        )

    def forward(self, x):
        """Return logits of shape (B, num_classes)."""
        x = self.model_features(x)
        # Flatten everything but the batch dim. The previous x.squeeze() also
        # removed the batch dim when B == 1, which BatchNorm1d rejects.
        x = torch.flatten(x, 1)
        return self.model_classifier(x)

#Target/Shadow Model for MNIST
class MNISTNet(nn.Module):
    """Two-stage CNN used as target/shadow model for MNIST-sized inputs.

    Each stage is 5x5 conv (no padding) -> BatchNorm -> ReLU -> 2x2 max-pool.
    The second stage doubles the channel count, and the FC head keeps that
    doubled width as its hidden size.

    Args:
        input_dim: number of input image channels.
        n_hidden: channels of the first conv stage (the second uses 2x).
        out_classes: number of output logits.
        size: side length of the (square) input image.
    """

    def __init__(self, input_dim, n_hidden, out_classes=10, size=28):
        super().__init__()
        wide = n_hidden * 2

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_dim, out_channels=n_hidden, kernel_size=5),
            nn.BatchNorm2d(n_hidden),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=n_hidden, out_channels=wide, kernel_size=5),
            nn.BatchNorm2d(wide),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        side = calc_feat_linear_mnist(size)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(side * side * wide, wide),
            nn.ReLU(inplace=True),
            nn.Linear(wide, out_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.conv2(self.conv1(x)))

#Attack MLP Model
class AttackMLP(nn.Module):
    """Attack model: a one-hidden-layer MLP (Linear -> ReLU -> Linear).

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        out_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size=1024, out_classes=2):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, out_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalized) logits for ``x``."""
        return self.classifier(x)


class ResNet(nn.Module):
    """CIFAR-style ResNet with four residual stages and quantization hooks.

    Compared with the ImageNet ResNet, the stem is a single 3x3 stride-1 conv
    with 16 channels and no max-pool. The stages use 16/32/64/128 base planes,
    and stages 2-4 downsample with stride 2.

    Besides the float ``forward``, the model exposes a quantization workflow
    built from the ``Q*`` wrappers imported from ``module``. It is presumably
    used in the order ``quantize`` -> ``quantize_forward`` (statistics
    collection; confirm in ``module``) -> ``freeze`` -> ``quantize_inference``.

    NOTE: construction writes the process-wide
    ``GlobalVariables.SELF_INPLANES``, which each ``MakeLayer_ResNet`` stage
    reads and updates in turn.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: number of output logits.
    """
    def __init__(self, block, layers, num_classes=10): # number of classes defaults to 10 (CIFAR-10)
        
        super(ResNet, self).__init__()


        self.inplanes = 16 # CIFAR-10 images are small, so the stem starts with fewer channels
        # MakeLayer_ResNet reads and updates this global to pass the channel
        # count from one stage to the next, so it must be seeded before layer1.
        GlobalVariables.SELF_INPLANES = self.inplanes
        # print('resnet init:'+ str(GlobalVariables.SELF_INPLANES))
        # Stem (input layer): 3x3 conv, stride 1
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

        # Residual stages: 16/32/64/128 base planes; stages 2-4 use stride 2
        self.layer1 = MakeLayer_ResNet(block, 16, layers[0])
        self.layer2 = MakeLayer_ResNet(block, 32, layers[1], stride=2)
        self.layer3 = MakeLayer_ResNet(block, 64, layers[2], stride=2)
        self.layer4 = MakeLayer_ResNet(block, 128, layers[3], stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128 * block.expansion, num_classes)

        # Parameter initialization (Linear layers keep PyTorch's default init)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def forward(self, x):
        """Float forward pass; returns un-normalized logits of shape (B, num_classes)."""
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Unlike the ImageNet ResNet there is no max-pool here: CIFAR-10 images
        # are already small and would become too small after another pool.

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Classification head
        x = self.avgpool(x)  # output shape is (B, C, 1, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # out = F.softmax(x,dim = 1)         # softmax is optional here; it makes little difference

        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Build the quantization wrappers for every layer.

        The stem wrapper is created with ``qi=True, qo=True``. All later
        wrappers are created with ``qi=False``, and their input params are
        wired up in ``freeze``. ``quant_type``, ``num_bits`` and ``e_bits`` are
        forwarded unchanged to the ``Q*`` wrappers from ``module``.
        """
        self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
        # num_bits and e_bits are forwarded to every stage
        self.layer1.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)      
        self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        # self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)

    def quantize_forward(self, x):
        """Forward pass through the wrappers built by ``quantize``; returns logits (no softmax)."""
        # for _, layer in self.quantize_layers.items():
        #     x = layer(x)

        # out = F.softmax(x, dim=1)
        # return out
        x = self.qconvbnrelu1(x)
        x = self.layer1.quantize_forward(x)
        x = self.layer2.quantize_forward(x)
        x = self.layer3.quantize_forward(x)
        x = self.layer4.quantize_forward(x)
        x = self.qavgpool1(x)
        x = x.view(x.size(0), -1)   
        x = self.qfc1(x)
        
        # out = F.softmax(x,dim = 1)         # softmax is optional here; it makes little difference
        return x
      

    def freeze(self):
        """Wire quantization params through the network in forward order.

        Each wrapper receives the previous wrapper's output params (``qo``) as
        its input params (``qi``).
        """
        self.qconvbnrelu1.freeze()  # the first layer owns its qi, so none has to be passed in
        qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
        qo = self.layer2.freeze(qinput = qo)
        qo = self.layer3.freeze(qinput = qo)
        qo = self.layer4.freeze(qinput = qo)
        self.qavgpool1.freeze(qi=qo)
        self.qfc1.freeze(qi=self.qavgpool1.qo)
        # self.qfc1.freeze()

    def fakefreeze(self):
        """Call ``fakefreeze`` on the stem, every stage and the FC wrapper.

        NOTE(review): ``qavgpool1`` is skipped, presumably because it has no
        weights; confirm against ``QAdaptiveAvgPool2d``.
        """
        self.qconvbnrelu1.fakefreeze()
        self.layer1.fakefreeze()
        self.layer2.fakefreeze()
        self.layer3.fakefreeze()
        self.layer4.fakefreeze()
        self.qfc1.fakefreeze()

    def quantize_inference(self, x):
        """Run inference in the quantized domain and return dequantized logits.

        The float input is quantized with the stem's ``qi``, passed through
        every quantized layer, and the FC output is dequantized with
        ``qfc1.qo``.
        """
        qx = self.qconvbnrelu1.qi.quantize_tensor(x)
        qx = self.qconvbnrelu1.quantize_inference(qx)
        qx = self.layer1.quantize_inference(qx)
        qx = self.layer2.quantize_inference(qx)
        qx = self.layer3.quantize_inference(qx)
        qx = self.layer4.quantize_inference(qx)
        qx = self.qavgpool1.quantize_inference(qx)
        qx = qx.view(qx.size(0), -1)
        qx = self.qfc1.quantize_inference(qx) 
        qx = self.qfc1.qo.dequantize_tensor(qx)
        

        # out = F.softmax(qx,dim = 1)         # softmax is optional here; it makes little difference
        return qx




# BasicBlock 类
class BasicBlock(nn.Module):
    """Two-conv residual block (3x3 -> 3x3) with an optional projection shortcut.

    ``expansion = 1``: the block outputs ``planes`` channels.

    Args:
        inplanes: input channel count.
        planes: output channel count of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the shortcut. ``quantize``
            indexes it as ``downsample[0]`` (conv) and ``downsample[1]`` (BN).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()

        # First 3x3 conv (carries the block's stride)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # Second 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride


    def forward(self, x):
        """Float forward: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""

        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        out = self.relu(out)

        return out
    
    def quantize(self, quant_type ,num_bits=8, e_bits=3):
        """Create quantization wrappers for this block's layers.

        All wrappers are created with ``qi=False``; their input params are
        supplied later in ``freeze``.
        """
        self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qconvbn1 = QConvBN(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        
        if self.downsample is not None:
            self.qconvbn2 =  QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
           
        self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits)   # needs a qi; it is supplied in freeze() from qelementadd.qo
        

    def quantize_forward(self, x):
        """Forward pass through the quantization wrappers created by ``quantize``."""
        identity = x
        out = self.qconvbnrelu1(x)
        out = self.qconvbn1(out)

        if self.downsample is not None:
            identity = self.qconvbn2(identity)
        
        # Residual add, done by the quantized elementwise-add wrapper
        # out = identity + out
        out = self.qelementadd(out,identity)
        out = self.qrelu1(out)
        return out

    def freeze(self, qinput):
        """Wire each wrapper's input params to the preceding output params.

        Args:
            qinput: output quantization params of the previous block/layer.
                They feed both the main path and the shortcut.

        Returns:
            The ReLU's input params, reused as this block's output params.
        """
        # NOTE(review): verify this wiring carefully.
        self.qconvbnrelu1.freeze(qi= qinput)   # takes the last qo of the previous module
        self.qconvbn1.freeze(qi = self.qconvbnrelu1.qo)

        if self.downsample is not None:
            self.qconvbn2.freeze(qi = qinput) # shortcut branch
            self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
        else:
            self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
        self.qrelu1.freeze(qi = self.qelementadd.qo) 
        return self.qrelu1.qi  # the post-ReLU qo can reuse the qi collected by the ReLU

    def fakefreeze(self):
        """Call ``fakefreeze`` on the conv+BN wrappers.

        NOTE(review): ``qelementadd`` and ``qrelu1`` are not fake-frozen,
        presumably because they hold no weights; confirm.
        """
        self.qconvbnrelu1.fakefreeze()
        self.qconvbn1.fakefreeze()

        if self.downsample is not None:
            self.qconvbn2.fakefreeze() # shortcut branch


    def quantize_inference(self, x):
        """Run the block on an already-quantized tensor and return a quantized tensor.

        No quantize_tensor/dequantize_tensor is needed here: this is neither
        the first nor the last layer, and every intermediate layer stays in
        the quantized domain.
        """
        identity = x
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbn1.quantize_inference(out)

        if self.downsample is not None:
            identity = self.qconvbn2.quantize_inference(identity)
        
        # Residual add in the quantized domain
        out = self.qelementadd.quantize_inference(out,identity)
        out = self.qrelu1.quantize_inference(out)
        return out


    


# Bottleneck 类
class Bottleneck(nn.Module):
    """Three-conv bottleneck residual block (1x1 -> 3x3 -> 1x1).

    ``expansion = 4``: the block outputs ``planes * 4`` channels.

    Args:
        inplanes: input channel count.
        planes: width of the inner 1x1/3x3 convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the shortcut. ``quantize``
            indexes it as ``downsample[0]`` (conv) and ``downsample[1]`` (BN).
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        # 1x1 reduction conv
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 conv (carries the block's stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 expansion conv
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        # Shortcut
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Float forward: three conv+BN layers, residual add, final ReLU."""

        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # the residual addition happens here
        out = self.relu(out)

        return out
    def quantize(self, quant_type ,num_bits=8, e_bits=3):
        """Create quantization wrappers for this block's layers.

        All wrappers are created with ``qi=False``; their input params are
        supplied later in ``freeze``.
        """
        self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qconvbnrelu2 = QConvBNReLU(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qconvbn1 = QConvBN(quant_type,self.conv3,self.bn3,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        
        if self.downsample is not None:
            self.qconvbn2 =  QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
           
        self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits)   # needs a qi; it is supplied in freeze() from qelementadd.qo
    
    def quantize_forward(self, x):
        """Forward pass through the quantization wrappers created by ``quantize``."""
        identity = x
        out = self.qconvbnrelu1(x)
        out = self.qconvbnrelu2(out)
        out = self.qconvbn1(out)

        if self.downsample is not None:
            identity = self.qconvbn2(identity)
        
        # Residual add, done by the quantized elementwise-add wrapper
        # out = identity + out
        out = self.qelementadd(out,identity)
        out = self.qrelu1(out)
        return out
    
    def freeze(self, qinput):
        """Wire each wrapper's input params to the preceding output params.

        Args:
            qinput: output quantization params of the previous block/layer.
                They feed both the main path and the shortcut.

        Returns:
            The ReLU's input params, reused as this block's output params.
        """
        # NOTE(review): verify this wiring carefully.
        self.qconvbnrelu1.freeze(qi= qinput)   # takes the last qo of the previous module
        self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
        self.qconvbn1.freeze(qi = self.qconvbnrelu2.qo)


        if self.downsample is not None:
            self.qconvbn2.freeze(qi = qinput) # shortcut branch
            self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
        else:
            self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
        self.qrelu1.freeze(qi = self.qelementadd.qo)  # the ReLU's qi comes from the add's qo
        return self.qrelu1.qi  # the post-ReLU qo can reuse the qi collected by the ReLU
    
    def fakefreeze(self):
        """Call ``fakefreeze`` on the conv+BN wrappers.

        NOTE(review): ``qelementadd`` and ``qrelu1`` are not fake-frozen,
        presumably because they hold no weights; confirm.
        """
        self.qconvbnrelu1.fakefreeze()   
        self.qconvbnrelu2.fakefreeze()
        self.qconvbn1.fakefreeze()


        if self.downsample is not None:
            self.qconvbn2.fakefreeze() # shortcut branch


    
    def quantize_inference(self, x):
        """Run the block on an already-quantized tensor and return a quantized tensor.

        No quantize_tensor/dequantize_tensor is needed here: this is neither
        the first nor the last layer, and every intermediate layer stays in
        the quantized domain.
        """
        identity = x
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbnrelu2.quantize_inference(out)
        out = self.qconvbn1.quantize_inference(out)

        if self.downsample is not None:
            identity = self.qconvbn2.quantize_inference(identity)
        
        # Residual add in the quantized domain
        out = self.qelementadd.quantize_inference(out,identity)
        out = self.qrelu1.quantize_inference(out)
        return out



class MakeLayer_ResNet(nn.Module):
    """One ResNet stage made of ``blocks`` residual blocks of type ``block``.

    The stage's input channel count is read from the process-wide
    ``GlobalVariables.SELF_INPLANES``. After the first block, that value is
    overwritten with ``planes * block.expansion``. Stages therefore have to be
    constructed in forward order (``ResNet.__init__`` seeds the value before
    building them), and building models concurrently from several threads
    would race on it.

    Args:
        block: residual block class exposing ``expansion`` (e.g. ``BasicBlock``
            or ``Bottleneck``).
        planes: base channel count of the blocks in this stage.
        blocks: number of blocks; ``block1`` is always created.
        stride: stride of the first block and of its 1x1 projection shortcut.
    """

    def __init__(self, block, planes, blocks, stride=1):
        super(MakeLayer_ResNet, self).__init__()
        self.downsample = None
        # A 1x1 conv + BN projection is needed whenever the shortcut must change
        # resolution or channel count. Note: it is registered both here and on
        # block1, so it is reachable under two attribute paths.
        if stride != 1 or GlobalVariables.SELF_INPLANES != planes * block.expansion:
            self.downsample = nn.Sequential(
            nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False), 
            nn.BatchNorm2d(planes * block.expansion)
            )
        # A ModuleDict keyed 'block1', 'block2', ... holds the blocks.
        self.blockdict = nn.ModuleDict()
        self.blockdict['block1'] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes, stride=stride, downsample=self.downsample)
        GlobalVariables.SELF_INPLANES = planes * block.expansion
        for i in range(1, blocks):
            self.blockdict['block' + str(i+1)] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes)

    def forward(self, x):
        """Float forward through every block in order."""
        for layer in self.blockdict.values():
            x = layer(x)
        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantization wrappers inside every block."""
        # Each block implements its own quantization strategy.
        for layer in self.blockdict.values():
            layer.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)

    def quantize_forward(self, x):
        """Forward through every block's quantization wrappers."""
        for layer in self.blockdict.values():
            x = layer.quantize_forward(x)
        return x

    def freeze(self, qinput):
        """Chain quantization params through the blocks.

        The first block is frozen with ``qinput`` and each later block with the
        previous block's output params.

        Args:
            qinput: output quantization params of the preceding layer.

        Returns:
            The last block's output params, for the next stage to consume.
        """
        qo = qinput
        for layer in self.blockdict.values():
            qo = layer.freeze(qinput=qo)
        return qo

    def fakefreeze(self):
        """Call ``fakefreeze`` on every block."""
        for layer in self.blockdict.values():
            layer.fakefreeze()

    def quantize_inference(self, x):
        """Run every block on an already-quantized tensor.

        No quantize/dequantize is needed here: intermediate layers stay in the
        quantized domain.
        """
        for layer in self.blockdict.values():
            x = layer.quantize_inference(x)
        return x




# 使用 ResNet18 模型
def resnet18(**kwargs):
    """Build a ResNet-18 style network: ``BasicBlock`` with two blocks per stage.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    stage_depths = [2] * 4
    return ResNet(BasicBlock, stage_depths, **kwargs)


# 使用 ResNet50 模型
def resnet50(**kwargs):
    """Build a ResNet-50 style network: ``Bottleneck`` with 3/4/6/3 blocks.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


# 使用 ResNet152 模型
def resnet152(**kwargs):
    """Build a ResNet-152 style network: ``Bottleneck`` with 3/8/36/3 blocks.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)




# ==========================================================
# MobileNetV2


class MobileNetV2(nn.Module):
    """MobileNetV2 variant for CIFAR-sized inputs, with quantization hooks.

    The stem conv uses stride 1, and the six stages use strides
    (1, 2, 2, 1, 2, 1). The input is therefore downsampled 8x before the
    final 1x1 conv, global average pool and FC layer.

    NOTE(review): ``forward`` returns logits, but ``quantize_forward`` and
    ``quantize_inference`` return softmax probabilities (``ResNet`` returns
    logits from all three paths). If the ``quantize_forward`` output is fed to
    ``nn.CrossEntropyLoss``, softmax is effectively applied twice; confirm
    this is intended.

    Args:
        num_classes: number of output classes.
    """
    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU6(inplace=True)
        # Bottleneck stages; t is the channel expansion factor
        self.layer1 = MakeLayer_Mobile(32, 16, 1, t=1, stride=1)
        self.layer2 = MakeLayer_Mobile(16, 24, 2, t=6, stride=2)
        self.layer3 = MakeLayer_Mobile(24, 32, 3, t=6, stride=2)
        # Stage layout/strides adjusted for the CIFAR-10 image size
        self.layer4 = MakeLayer_Mobile(32, 96, 3, t=6, stride=1)
        self.layer5 = MakeLayer_Mobile(96, 160, 3, t=6, stride=2)
        self.layer6 = MakeLayer_Mobile(160, 320, 1, t=6, stride=1)

        # NOTE(review): no BN/activation follows this 1x1 conv (the reference
        # MobileNetV2 uses BN + ReLU6 here); confirm intended.
        self.conv2 = nn.Conv2d(320, 1280, 1)
        self.avg1 = nn.AdaptiveAvgPool2d(1)


        self.fc = nn.Linear(1280, num_classes)


    def forward(self, x):
        """Float forward pass; returns logits of shape (B, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.conv2(x)
        x = self.avg1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    
    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Build the quantization wrappers for every layer.

        The stem wrapper is created with ``qi=True, qo=True``. All later
        wrappers are created with ``qi=False``, and their input params are
        wired up in ``freeze``.
        """
        self.qconvbnrelu1 = QConvBNReLU6(quant_type,self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
        # num_bits and e_bits are forwarded to every stage
        
        self.layer1.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)      
        self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer5.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
        self.layer6.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)

        self.qconv1 = QConv2d(quant_type, self.conv2, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)

    def quantize_forward(self, x):
        """Forward pass through the quantization wrappers; returns softmax probabilities."""
        # for _, layer in self.quantize_layers.items():
        #     x = layer(x)

        # out = F.softmax(x, dim=1)
        # return out
        x = self.qconvbnrelu1(x)
        x = self.layer1.quantize_forward(x)
        x = self.layer2.quantize_forward(x)
        x = self.layer3.quantize_forward(x)
        x = self.layer4.quantize_forward(x)
        x = self.layer5.quantize_forward(x)
        x = self.layer6.quantize_forward(x)
        x = self.qconv1(x)
        x = self.qavgpool1(x)
        x = x.view(x.size(0), -1)   
        x = self.qfc1(x)
        
        out = F.softmax(x,dim = 1)         # returns probabilities, unlike forward() which returns logits
        return out

    def freeze(self):
        """Wire quantization params through the network in forward order.

        Each wrapper receives the previous wrapper's output params (``qo``) as
        its input params (``qi``).
        """
        self.qconvbnrelu1.freeze()  # the first layer owns its qi, so none has to be passed in
        qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
        qo = self.layer2.freeze(qinput = qo)
        qo = self.layer3.freeze(qinput = qo)
        qo = self.layer4.freeze(qinput = qo)
        qo = self.layer5.freeze(qinput = qo)
        qo = self.layer6.freeze(qinput = qo)
        self.qconv1.freeze(qi = qo)
        self.qavgpool1.freeze(qi=self.qconv1.qo)
        self.qfc1.freeze(qi=self.qavgpool1.qo)
        # self.qfc1.freeze()

    def fakefreeze(self):
        """Call ``fakefreeze`` on every weighted wrapper.

        NOTE(review): ``qavgpool1`` is skipped, presumably because it has no
        weights; confirm against ``QAdaptiveAvgPool2d``.
        """
        self.qconvbnrelu1.fakefreeze()
        self.layer1.fakefreeze()
        self.layer2.fakefreeze()
        self.layer3.fakefreeze()
        self.layer4.fakefreeze()
        self.layer5.fakefreeze()
        self.layer6.fakefreeze()
        self.qconv1.fakefreeze()
        self.qfc1.fakefreeze()

    
    def quantize_inference(self, x):
        """Run inference in the quantized domain; returns softmax probabilities.

        The float input is quantized with the stem's ``qi``, passed through
        every quantized layer, and the FC output is dequantized with
        ``qfc1.qo`` before the softmax.
        """

        qx = self.qconvbnrelu1.qi.quantize_tensor(x)
        qx = self.qconvbnrelu1.quantize_inference(qx)

        qx = self.layer1.quantize_inference(qx)
        qx = self.layer2.quantize_inference(qx)
        qx = self.layer3.quantize_inference(qx)
        qx = self.layer4.quantize_inference(qx)
        qx = self.layer5.quantize_inference(qx)
        qx = self.layer6.quantize_inference(qx)

        qx = self.qconv1.quantize_inference(qx)
        qx = self.qavgpool1.quantize_inference(qx)
        qx = qx.view(qx.size(0), -1)
        qx = self.qfc1.quantize_inference(qx) 
        qx = self.qfc1.qo.dequantize_tensor(qx)
        
        out = F.softmax(qx,dim = 1)         # returns probabilities, unlike forward() which returns logits
        return out



 

class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    The layers are: 1x1 expansion conv -> BN -> ReLU6 -> 3x3 depthwise conv
    (carries ``stride``) -> BN -> ReLU6 -> 1x1 linear projection conv -> BN.
    The input is added back only when ``stride == 1`` and the channel count
    is unchanged.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: stride of the depthwise conv.
        expand_ratio: the hidden width is ``int(in_channels * expand_ratio)``.
    """
    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_dims = int(in_channels * expand_ratio)
        # The shortcut is only possible when the output has the input's shape.
        self.identity_flag = stride == 1 and in_channels == out_channels

        # The layers are kept as separate attributes (rather than one
        # nn.Sequential) because quantize() wraps each conv+BN pair individually.
        # Pointwise (1x1) expansion convolution
        self.conv1 = nn.Conv2d(in_channels, hidden_dims, 1)
        self.bn1 = nn.BatchNorm2d(hidden_dims)
        self.relu1 = nn.ReLU6(inplace=True)
        # Depthwise Convolution
        self.conv2 = nn.Conv2d(hidden_dims, hidden_dims, 3, stride=stride, padding=1, groups=hidden_dims)
        self.bn2 = nn.BatchNorm2d(hidden_dims)
        self.relu2 = nn.ReLU6(inplace=True)
        # Pointwise & Linear Convolution (no activation afterwards)
        self.conv3 = nn.Conv2d(hidden_dims, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Float forward; adds the input back when ``identity_flag`` is set."""
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_flag:
            return identity + x
        
        else:
            return x

    
    def quantize(self, quant_type ,num_bits=8, e_bits=3):
        """Create quantization wrappers for the three conv+BN groups and the residual add.

        NOTE(review): ``qelementadd`` is created even when ``identity_flag`` is
        False, in which case it is never used.
        """
        self.qconvbnrelu1 = QConvBNReLU6(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qconvbnrelu2 = QConvBNReLU6(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        self.qconvbn1 = QConvBN(quant_type,self.conv3,self.bn3,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
        

        self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)

   
    def quantize_forward(self, x):
        """Forward pass through the quantization wrappers created by ``quantize``."""
        identity = x
        out = self.qconvbnrelu1(x)
        out = self.qconvbnrelu2(out)
        out = self.qconvbn1(out)

        if self.identity_flag:
            out = self.qelementadd(out, identity)

        return out
    
    def freeze(self, qinput):
        """Wire each wrapper's input params to the preceding output params.

        Args:
            qinput: output quantization params of the previous block/layer.

        Returns:
            The residual add's output params when ``identity_flag`` is set,
            otherwise the last conv+BN wrapper's output params.
        """
        # NOTE(review): verify this wiring carefully.
        self.qconvbnrelu1.freeze(qi= qinput)   # takes the last qo of the previous module
        self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
        self.qconvbn1.freeze(qi = self.qconvbnrelu2.qo)

        if self.identity_flag:
            self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
            return self.qelementadd.qo
        else:
            return self.qconvbn1.qo
    
    def fakefreeze(self):
        """Call ``fakefreeze`` on the three conv+BN wrappers (the add wrapper is not touched)."""
        self.qconvbnrelu1.fakefreeze()
        self.qconvbnrelu2.fakefreeze()
        self.qconvbn1.fakefreeze()


    
    def quantize_inference(self, x):
        """Run the block on an already-quantized tensor and return a quantized tensor.

        No quantize_tensor/dequantize_tensor is needed here: this is neither
        the first nor the last layer, and every intermediate layer stays in
        the quantized domain.
        """
        identity = x
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbnrelu2.quantize_inference(out)
        out = self.qconvbn1.quantize_inference(out)

        if self.identity_flag:
            out = self.qelementadd.quantize_inference(out, identity)        
        
        return out



class MakeLayer_Mobile(nn.Module):
    """A MobileNetV2 stage of ``n_repeat`` stacked ``InvertedResidual`` blocks.

    The first block maps ``in_channels`` to ``out_channels`` with ``stride``.
    The remaining blocks keep ``out_channels`` and use stride 1. ``t`` is the
    expansion ratio passed to every block. With ``n_repeat == 0`` the stage
    is an identity (``freeze`` then returns ``qinput`` unchanged).

    Args:
        in_channels: input channel count of the stage.
        out_channels: output channel count of every block.
        n_repeat: number of blocks.
        t: expansion ratio of each block.
        stride: stride of the first block.
    """

    def __init__(self, in_channels, out_channels, n_repeat, t, stride):
        super(MakeLayer_Mobile, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(n_repeat):
            # Only the first block of the stage changes resolution/width.
            block_stride = stride if i == 0 else 1
            self.layers.append(InvertedResidual(in_channels, out_channels, block_stride, t))
            in_channels = out_channels

    def forward(self, x):
        """Float forward through every block in order."""
        for layer in self.layers:
            x = layer(x)
        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantization wrappers inside every block."""
        # Each block implements its own quantization strategy.
        for layer in self.layers:
            layer.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)

    def quantize_forward(self, x):
        """Forward through every block's quantization wrappers."""
        for layer in self.layers:
            x = layer.quantize_forward(x)
        return x

    def freeze(self, qinput):
        """Chain quantization params through the blocks.

        The first block is frozen with ``qinput`` and each later block with
        the previous block's output params. Seeding the accumulator with
        ``qinput`` also makes an empty stage return ``qinput`` instead of
        raising UnboundLocalError.

        Args:
            qinput: output quantization params of the preceding layer.

        Returns:
            The last block's output params, for the next stage to consume.
        """
        qo = qinput
        for layer in self.layers:
            qo = layer.freeze(qinput=qo)
        return qo

    def fakefreeze(self):
        """Call ``fakefreeze`` on every block."""
        for layer in self.layers:
            layer.fakefreeze()

    def quantize_inference(self, x):
        """Run every block on an already-quantized tensor.

        No quantize/dequantize is needed here: intermediate layers stay in the
        quantized domain.
        """
        for layer in self.layers:
            x = layer.quantize_inference(x)
        return x
