import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *
import module
from global_var import GlobalVariables


# 定义 ResNet 模型
# 适用于Cifar10
class ResNet(nn.Module):
    """ResNet sized for CIFAR-10 style inputs, with a mirrored quantized path.

    Float path: 3x3 conv stem (16 channels) -> BN -> ReLU -> four
    ``MakeLayer`` stages of width 16/32/64/128 -> global average pool ->
    linear classifier -> softmax.

    Quantized path: ``quantize`` creates the ``Q*`` wrapper modules that the
    other quantization methods (``quantize_forward``, ``freeze``,
    ``fakefreeze``, ``quantize_inference``) read, so it has to be called
    before any of them.
    NOTE(review): the behaviour of the ``Q*`` wrappers is defined in the
    project's ``module`` package and is not visible from this file.
    """

    def __init__(self, block, layers, num_classes=10):
        """Build the float network.

        Args:
            block: Residual block class (``BasicBlock`` or ``Bottleneck``);
                its ``expansion`` attribute sizes the classifier input.
            layers: Indexable whose entries 0..3 give the number of blocks in
                each of the four stages.
            num_classes: Width of the classifier output (10 by default).
        """
        super().__init__()

        # CIFAR-10 images are small, so the stem starts with few channels.
        stem_channels = 16
        self.inplanes = stem_channels
        # MakeLayer reads the running channel count from this shared global,
        # so it has to be reset before the first stage is constructed.
        GlobalVariables.SELF_INPLANES = stem_channels

        # Stem.
        self.conv1 = nn.Conv2d(3, stem_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_channels)
        self.relu = nn.ReLU()

        # Four residual stages as (width, stride); they are created in order
        # because each MakeLayer updates GlobalVariables.SELF_INPLANES.
        stage_plan = ((16, 1), (32, 2), (64, 2), (128, 2))
        for idx, (width, stride) in enumerate(stage_plan):
            stage = MakeLayer(block, width, layers[idx], stride=stride)
            setattr(self, 'layer' + str(idx + 1), stage)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128 * block.expansion, num_classes)

        # Weight initialisation: Kaiming-normal for convolutions, identity
        # affine (weight 1, bias 0) for batch norms.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

    def _stages(self):
        """Return the four residual stages in execution order."""
        return (self.layer1, self.layer2, self.layer3, self.layer4)

    def forward(self, x):
        """Run the float network and return softmax probabilities over dim 1."""
        # Stem. Unlike the ImageNet variant there is no max-pool here:
        # CIFAR-10 images are already small.
        x = self.relu(self.bn1(self.conv1(x)))

        for stage in self._stages():
            x = stage(x)

        pooled = self.avgpool(x)  # shape B, C, 1, 1
        logits = self.fc(pooled.view(pooled.size(0), -1))

        # Original author's note: the softmax could be skipped with little effect.
        return F.softmax(logits, dim=1)

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantized wrapper modules for the stem, stages and head.

        Args:
            quant_type: Passed through unchanged to every ``Q*`` wrapper.
            num_bits: Passed through unchanged to every ``Q*`` wrapper.
            e_bits: Passed through unchanged to every ``Q*`` wrapper.
        """
        bits = dict(num_bits=num_bits, e_bits=e_bits)

        # The stem is the only wrapper created with its own input quantizer (qi=True).
        self.qconvbnrelu1 = QConvBNReLU(quant_type, self.conv1, self.bn1,
                                        qi=True, qo=True, **bits)
        for stage in self._stages():
            stage.quantize(quant_type=quant_type, **bits)
        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type, qi=False, qo=True, **bits)
        self.qfc1 = QLinear(quant_type, self.fc, qi=False, qo=True, **bits)

    def quantize_forward(self, x):
        """Run ``x`` through the quantized wrappers and return softmax probabilities."""
        x = self.qconvbnrelu1(x)
        for stage in self._stages():
            x = stage.quantize_forward(x)
        x = self.qavgpool1(x)
        x = self.qfc1(x.view(x.size(0), -1))

        # Original author's note: the softmax could be skipped with little effect.
        return F.softmax(x, dim=1)

    def freeze(self):
        """Freeze every wrapper, handing each one its predecessor's output quantizer."""
        # The stem owns its input quantizer, so no qi is supplied for it.
        self.qconvbnrelu1.freeze()

        qparam = self.qconvbnrelu1.qo
        for stage in self._stages():
            qparam = stage.freeze(qinput=qparam)

        self.qavgpool1.freeze(qi=qparam)
        self.qfc1.freeze(qi=self.qavgpool1.qo)

    def fakefreeze(self):
        """Call ``fakefreeze`` on the stem, every stage and the classifier wrapper.

        The pooling wrapper is not touched here.
        """
        self.qconvbnrelu1.fakefreeze()
        for stage in self._stages():
            stage.fakefreeze()
        self.qfc1.fakefreeze()

    def quantize_inference(self, x):
        """Run inference through the frozen quantized path.

        The input is quantized with the stem's input quantizer, passed
        through every wrapper's ``quantize_inference``, dequantized with the
        classifier's output quantizer, and finally turned into softmax
        probabilities.
        """
        qx = self.qconvbnrelu1.qi.quantize_tensor(x)
        qx = self.qconvbnrelu1.quantize_inference(qx)
        for stage in self._stages():
            qx = stage.quantize_inference(qx)
        qx = self.qavgpool1.quantize_inference(qx)
        qx = qx.view(qx.size(0), -1)
        qx = self.qfc1.quantize_inference(qx)
        logits = self.qfc1.qo.dequantize_tensor(qx)

        # Original author's note: the softmax could be skipped with little effect.
        return F.softmax(logits, dim=1)




# BasicBlock 类
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with an optional projection shortcut.

    Besides the float ``forward`` it carries a quantized twin built from the
    ``Q*`` wrappers; ``quantize`` must be called before the other
    quantization methods because it creates the attributes they read.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the float layers.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input to produce the
                shortcut branch; ``None`` means the input is added as is.
        """
        super().__init__()

        # Main path: the first 3x3 conv carries the stride, the second keeps
        # the resolution.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut and shared activation.
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantized wrappers for this block.

        All three arguments are forwarded unchanged to every wrapper.
        ``downsample`` is indexed at 0 and 1, which are passed to ``QConvBN``
        as its two layer arguments (the conv/BN pair when built by ``MakeLayer``).
        """
        bits = dict(num_bits=num_bits, e_bits=e_bits)

        self.qconvbnrelu1 = QConvBNReLU(quant_type, self.conv1, self.bn1,
                                        qi=False, qo=True, **bits)
        self.qconvbn1 = QConvBN(quant_type, self.conv2, self.bn2,
                                qi=False, qo=True, **bits)

        if self.downsample is not None:
            proj_conv, proj_bn = self.downsample[0], self.downsample[1]
            self.qconvbn2 = QConvBN(quant_type, proj_conv, proj_bn,
                                    qi=False, qo=True, **bits)

        self.qelementadd = QElementwiseAdd(quant_type, qi0=False, qi1=False,
                                           qo=True, **bits)
        # Created without its own qi; freeze() supplies one.
        self.qrelu1 = QReLU(quant_type, qi=False, **bits)

    def quantize_forward(self, x):
        """Run the block through its quantized wrappers (callable form)."""
        out = self.qconvbn1(self.qconvbnrelu1(x))

        shortcut = x if self.downsample is None else self.qconvbn2(x)

        # The residual add goes through the dedicated element-wise add wrapper.
        out = self.qelementadd(out, shortcut)
        return self.qrelu1(out)

    def freeze(self, qinput):
        """Freeze the wrappers, chaining quantization parameters through the block.

        Args:
            qinput: Quantization parameters of the tensor entering this block
                (the previous module's output quantizer).

        Returns:
            ``self.qrelu1.qi``, which the caller uses as the quantization
            parameters of this block's output.
        """
        self.qconvbnrelu1.freeze(qi=qinput)
        self.qconvbn1.freeze(qi=self.qconvbnrelu1.qo)

        if self.downsample is None:
            # Identity shortcut: the add sees the block input directly.
            shortcut_q = qinput
        else:
            # Projection shortcut: a separate branch fed by the block input.
            self.qconvbn2.freeze(qi=qinput)
            shortcut_q = self.qconvbn2.qo

        self.qelementadd.freeze(qi0=self.qconvbn1.qo, qi1=shortcut_q)
        self.qrelu1.freeze(qi=self.qelementadd.qo)
        return self.qrelu1.qi

    def fakefreeze(self):
        """Call ``fakefreeze`` on the conv wrappers (main path and projection, if any)."""
        self.qconvbnrelu1.fakefreeze()
        self.qconvbn1.fakefreeze()

        if self.downsample is not None:
            self.qconvbn2.fakefreeze()

    def quantize_inference(self, x):
        """Run the block via each wrapper's ``quantize_inference``.

        No quantize/dequantize step happens here; the original author's note
        is that interior blocks stay in the quantized domain throughout.
        """
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbn1.quantize_inference(out)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.qconvbn2.quantize_inference(x)

        out = self.qelementadd.quantize_inference(out, shortcut)
        return self.qrelu1.quantize_inference(out)


    


# Bottleneck 类
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with optional projection shortcut.

    The last 1x1 convolution widens the output to ``planes * expansion``
    channels. Besides the float ``forward`` the block carries a quantized
    twin built from the ``Q*`` wrappers; ``quantize`` must be called before
    the other quantization methods because it creates the attributes they read.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the float layers.

        Args:
            inplanes: Number of input channels.
            planes: Channel count of the two inner convolutions.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input to produce the
                shortcut branch; ``None`` means the input is added as is.
        """
        super().__init__()
        out_channels = planes * self.expansion

        # 1x1 reduce.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 conv; this is the layer that carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 expand.
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Shortcut and shared activation.
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut  # the residual addition happens here
        return self.relu(out)

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantized wrappers for this block.

        All three arguments are forwarded unchanged to every wrapper.
        ``downsample`` is indexed at 0 and 1, which are passed to ``QConvBN``
        as its two layer arguments (the conv/BN pair when built by ``MakeLayer``).
        """
        bits = dict(num_bits=num_bits, e_bits=e_bits)

        self.qconvbnrelu1 = QConvBNReLU(quant_type, self.conv1, self.bn1,
                                        qi=False, qo=True, **bits)
        self.qconvbnrelu2 = QConvBNReLU(quant_type, self.conv2, self.bn2,
                                        qi=False, qo=True, **bits)
        self.qconvbn1 = QConvBN(quant_type, self.conv3, self.bn3,
                                qi=False, qo=True, **bits)

        if self.downsample is not None:
            proj_conv, proj_bn = self.downsample[0], self.downsample[1]
            self.qconvbn2 = QConvBN(quant_type, proj_conv, proj_bn,
                                    qi=False, qo=True, **bits)

        self.qelementadd = QElementwiseAdd(quant_type, qi0=False, qi1=False,
                                           qo=True, **bits)
        # Created without its own qi; freeze() supplies one.
        self.qrelu1 = QReLU(quant_type, qi=False, **bits)

    def quantize_forward(self, x):
        """Run the block through its quantized wrappers (callable form)."""
        out = self.qconvbnrelu1(x)
        out = self.qconvbnrelu2(out)
        out = self.qconvbn1(out)

        shortcut = x if self.downsample is None else self.qconvbn2(x)

        # The residual add goes through the dedicated element-wise add wrapper.
        out = self.qelementadd(out, shortcut)
        return self.qrelu1(out)

    def freeze(self, qinput):
        """Freeze the wrappers, chaining quantization parameters through the block.

        Args:
            qinput: Quantization parameters of the tensor entering this block
                (the previous module's output quantizer).

        Returns:
            ``self.qrelu1.qi``, which the caller uses as the quantization
            parameters of this block's output.
        """
        self.qconvbnrelu1.freeze(qi=qinput)
        self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
        self.qconvbn1.freeze(qi=self.qconvbnrelu2.qo)

        if self.downsample is None:
            # Identity shortcut: the add sees the block input directly.
            shortcut_q = qinput
        else:
            # Projection shortcut: a separate branch fed by the block input.
            self.qconvbn2.freeze(qi=qinput)
            shortcut_q = self.qconvbn2.qo

        self.qelementadd.freeze(qi0=self.qconvbn1.qo, qi1=shortcut_q)
        self.qrelu1.freeze(qi=self.qelementadd.qo)
        return self.qrelu1.qi

    def fakefreeze(self):
        """Call ``fakefreeze`` on the conv wrappers (main path and projection, if any)."""
        self.qconvbnrelu1.fakefreeze()
        self.qconvbnrelu2.fakefreeze()
        self.qconvbn1.fakefreeze()

        if self.downsample is not None:
            self.qconvbn2.fakefreeze()

    def quantize_inference(self, x):
        """Run the block via each wrapper's ``quantize_inference``.

        No quantize/dequantize step happens here; the original author's note
        is that interior blocks stay in the quantized domain throughout.
        """
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbnrelu2.quantize_inference(out)
        out = self.qconvbn1.quantize_inference(out)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.qconvbn2.quantize_inference(x)

        out = self.qelementadd.quantize_inference(out, shortcut)
        return self.qrelu1.quantize_inference(out)



class MakeLayer(nn.Module):
    """One residual stage: ``blocks`` residual blocks of the same width run in sequence.

    Blocks are stored in ``self.blockdict`` under the keys ``'block1'``,
    ``'block2'``, ... and every method simply walks them in that order.
    The input channel count is taken from ``GlobalVariables.SELF_INPLANES``
    and the global is updated to this stage's output width, so stages must
    be constructed in network order.
    """

    def __init__(self, block, planes, blocks, stride=1):
        """Build the stage.

        Args:
            block: Residual block class; must expose ``expansion`` and accept
                ``inplanes``, ``planes``, ``stride`` and ``downsample``.
            planes: Base width of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block (and of the projection shortcut).
        """
        super().__init__()

        in_channels = GlobalVariables.SELF_INPLANES
        out_channels = planes * block.expansion

        # The first block needs a 1x1 conv + BN projection on its shortcut
        # whenever it changes the resolution or the channel count.
        # NOTE: the projection is kept as an attribute of the stage as well
        # as being handed to the first block.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

        self.blockdict = nn.ModuleDict()
        self.blockdict['block1'] = block(inplanes=in_channels, planes=planes,
                                         stride=stride, downsample=self.downsample)

        # Every later block (and the next stage) sees the widened channel count.
        GlobalVariables.SELF_INPLANES = out_channels
        for idx in range(2, blocks + 1):
            self.blockdict['block' + str(idx)] = block(inplanes=out_channels, planes=planes)

    def forward(self, x):
        """Apply every block in order (float path)."""
        for layer in self.blockdict.values():
            x = layer(x)
        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Ask every block to create its quantized wrappers with the given settings."""
        for layer in self.blockdict.values():
            layer.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)

    def quantize_forward(self, x):
        """Apply every block's ``quantize_forward`` in order."""
        for layer in self.blockdict.values():
            x = layer.quantize_forward(x)
        return x

    def freeze(self, qinput):
        """Freeze every block, feeding each one the previous block's result.

        Args:
            qinput: Quantization parameters of the tensor entering the stage.

        Returns:
            The value returned by the last block's ``freeze``, for use by the
            module that follows this stage.
        """
        qo = qinput
        for layer in self.blockdict.values():
            qo = layer.freeze(qinput=qo)
        return qo

    def fakefreeze(self):
        """Call ``fakefreeze`` on every block."""
        for layer in self.blockdict.values():
            layer.fakefreeze()

    def quantize_inference(self, x):
        """Apply every block's ``quantize_inference`` in order.

        No quantize/dequantize step happens here; the original author's note
        is that interior layers stay in the quantized domain throughout.
        """
        for layer in self.blockdict.values():
            x = layer.quantize_inference(x)
        return x




# 使用 ResNet18 模型
def resnet18(**kwargs):
    """Build a ResNet-18: ``BasicBlock`` with 2 blocks in each of the four stages.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


# 使用 ResNet50 模型
def resnet50(**kwargs):
    """Build a ResNet-50: ``Bottleneck`` with 3/4/6/3 blocks in the four stages.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


# 使用 ResNet152 模型
def resnet152(**kwargs):
    """Build a ResNet-152: ``Bottleneck`` with 3/8/36/3 blocks in the four stages.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
