# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from new_train import get_children
from global_var import GlobalVariables
from module import *


class LeNet(nn.Module):
    """LeNet-5 style classifier with a float path and a quantized path.

    Float path: ``forward``. Quantized path: ``quantize`` wraps the float
    layers in quantized modules, ``quantize_forward`` runs those wrappers,
    ``freeze`` / ``fakefreeze`` fix their quantization parameters, and
    ``quantize_inference`` runs the frozen quantized model end to end.

    FLOPs reference used in the layer comments:
      conv, counting bias:       (2 * C_in * K_h * K_w) * H_out * W_out * C_out
      conv, not counting bias:   (2 * C_in * K_h * K_w - 1) * H_out * W_out * C_out
      linear, counting bias:     (2 * I) * O
      linear, not counting bias: (2 * I - 1) * O

    Args:
        img_size: Height/width of the square input image; determines the
            input size of fc1 (16 * 5 * 5 for the default 32).
        input_channel: Number of channels of the input image (conv1 input).
        num_class: Number of output classes.
        n_exp: Forwarded to every quantized layer built in ``quantize``
            (its meaning is defined in the ``module`` package).
        mode: Quantization mode forwarded to the quantized layers; in mode 1
            ``quantize_inference`` dequantizes the final output.

    Raises:
        ValueError: If ``img_size`` is too small for two conv(5x5)+pool(2)
            stages (i.e. below 16).
    """

    def __init__(self, img_size=32, input_channel=3, num_class=10, n_exp=4, mode=1):
        super().__init__()
        # Spatial size after conv(5x5) -> pool(2) -> conv(5x5) -> pool(2);
        # 5 for the default 32x32 input.
        feat_size = ((img_size - 4) // 2 - 4) // 2
        if feat_size < 1:
            raise ValueError(f"img_size={img_size} is too small for LeNet (needs >= 16)")

        self.conv_layers = nn.ModuleDict({
            # block1 -- default 3x32x32 input gives a 6x28x28 output:
            # (2*3*5*5) * 28*28*6 FLOPs counting bias
            'conv1': nn.Conv2d(input_channel, 6, 5),
            'reluc1': nn.ReLU(),
            'pool1': nn.MaxPool2d(2, 2),
            # block2 -- 6x14x14 in, 16x10x10 out: (2*6*5*5) * 10*10*16 FLOPs counting bias
            'conv2': nn.Conv2d(6, 16, 5),
            'reluc2': nn.ReLU(),
            'pool2': nn.MaxPool2d(2, 2),
        })

        self.fc_layers = nn.ModuleDict({
            # classifier -- fc1: (2*16*5*5)*120 FLOPs counting bias for a 32x32 input
            'fc1': nn.Linear(16 * feat_size * feat_size, 120),
            'reluf1': nn.ReLU(),
            'fc2': nn.Linear(120, 84),  # (2*120)*84 FLOPs counting bias
            'reluf2': nn.ReLU(),
            'fc3': nn.Linear(84, num_class),
        })
        self.mode = mode
        self.n_exp = n_exp

    def _flatten(self, x):
        """Reshape conv features to (-1, fc1.in_features) for the classifier."""
        return x.view(-1, self.fc_layers['fc1'].in_features)

    def forward(self, x):
        """Float forward pass; returns class probabilities (softmax over dim 1)."""
        for layer in self.conv_layers.values():
            x = layer(x)

        output = self._flatten(x)
        for layer in self.fc_layers.values():
            output = layer(output)

        # NOTE(review): the output is already softmax-ed; if the training loss
        # is nn.CrossEntropyLoss, softmax is effectively applied twice.
        return F.softmax(output, dim=1)

    def quantize(self, num_bits=8):
        """Build quantized wrappers around the float layers.

        A layer built with qi=True observes and quantizes its own input. The
        design assumption here is that ReLU and max-pooling do not change the
        integer values or their min/max, so the layers after them are built
        with qi=False and receive the upstream ``qo`` in ``freeze``.

        Args:
            num_bits: Bit width passed to the quantized conv / linear layers.
        """
        common = {'n_exp': self.n_exp, 'mode': self.mode}

        self.quantize_conv_layers = nn.ModuleDict({
            'qconv1': QConv2d(self.conv_layers['conv1'], qi=True, qo=True, num_bits=num_bits, **common),
            'qreluc1': QReLU(**common),
            'qpool1': QMaxPooling2d(kernel_size=2, stride=2, padding=0, **common),
            'qconv2': QConv2d(self.conv_layers['conv2'], qi=False, qo=True, num_bits=num_bits, **common),
            'qreluc2': QReLU(**common),
            'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0, **common),
        })

        self.quantize_fc_layers = nn.ModuleDict({
            'qfc1': QLinear(self.fc_layers['fc1'], qi=False, qo=True, num_bits=num_bits, **common),
            'qreluf1': QReLU(**common),
            'qfc2': QLinear(self.fc_layers['fc2'], qi=False, qo=True, num_bits=num_bits, **common),
            'qreluf2': QReLU(**common),
            'qfc3': QLinear(self.fc_layers['fc3'], qi=False, qo=True, num_bits=num_bits, **common),
        })

    def quantize_forward(self, x):
        """Run the quantized wrappers (built by ``quantize``); returns softmax probabilities."""
        for layer in self.quantize_conv_layers.values():
            x = layer(x)

        output = self._flatten(x)
        for layer in self.quantize_fc_layers.values():
            output = layer(output)

        # Softmax is optional here; it only matters for computing the loss.
        return F.softmax(output, dim=1)

    def freeze(self):
        """Freeze every quantized layer, chaining each layer's input params from upstream ``qo``."""
        conv = self.quantize_conv_layers
        fc = self.quantize_fc_layers

        conv['qconv1'].freeze()
        conv['qreluc1'].freeze(conv['qconv1'].qo)
        conv['qpool1'].freeze(conv['qconv1'].qo)

        conv['qconv2'].freeze(conv['qconv1'].qo)
        conv['qreluc2'].freeze(conv['qconv2'].qo)
        conv['qpool2'].freeze(conv['qconv2'].qo)

        fc['qfc1'].freeze(qi=conv['qconv2'].qo)
        fc['qreluf1'].freeze(fc['qfc1'].qo)
        fc['qfc2'].freeze(qi=fc['qfc1'].qo)
        fc['qreluf2'].freeze(fc['qfc2'].qo)
        fc['qfc3'].freeze(qi=fc['qfc2'].qo)

    def fakefreeze(self):
        """Same chaining as ``freeze`` but calls each layer's ``fakefreeze``."""
        conv = self.quantize_conv_layers
        fc = self.quantize_fc_layers

        conv['qconv1'].fakefreeze()
        conv['qreluc1'].fakefreeze(conv['qconv1'].qo)
        conv['qpool1'].fakefreeze(conv['qconv1'].qo)

        conv['qconv2'].fakefreeze(conv['qconv1'].qo)
        conv['qreluc2'].fakefreeze(conv['qconv2'].qo)
        conv['qpool2'].fakefreeze(conv['qconv2'].qo)

        fc['qfc1'].fakefreeze(qi=conv['qconv2'].qo)
        fc['qreluf1'].fakefreeze(fc['qfc1'].qo)
        fc['qfc2'].fakefreeze(qi=fc['qfc1'].qo)
        fc['qreluf2'].fakefreeze(fc['qfc2'].qo)
        fc['qfc3'].fakefreeze(qi=fc['qfc2'].qo)

    def quantize_inference(self, x):
        """Quantize ``x``, run the frozen quantized layers, and return softmax probabilities."""
        x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode)

        for layer in self.quantize_conv_layers.values():
            x = layer.quantize_inference(x)

        output = self._flatten(x)
        for layer in self.quantize_fc_layers.values():
            output = layer.quantize_inference(output)

        # Only mode 1 maps the quantized output back to the original value
        # range. Per the original design note, PoT quantization does not need
        # this because it restores the range by itself.
        if self.mode == 1:
            output = self.quantize_fc_layers['qfc3'].qo.dequantize_tensor(output, self.mode)

        # TODO: a quantized softmax may fit better here.
        return F.softmax(output, dim=1)







class NetBN(nn.Module):
    """Two conv+BN+ReLU+max-pool stages followed by a single linear classifier.

    The flatten step uses a fixed 5*5*40 feature size, which matches a 28x28
    input (28 -> conv3 26 -> pool 13 -> conv3 11 -> pool 5).

    Args:
        num_channels: Number of input image channels.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.bn1 = nn.BatchNorm2d(40)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.bn2 = nn.BatchNorm2d(40)
        self.fc = nn.Linear(5 * 5 * 40, 10)

    def forward(self, x):
        """Float forward pass; returns raw logits."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.max_pool2d(F.relu(norm(conv(x))), 2, 2)
        return self.fc(x.view(-1, 5 * 5 * 40))

    def quantize(self, num_bits=8):
        """Create the fused conv+BN+ReLU, pooling and linear quantized wrappers."""
        bits = num_bits
        self.qconv1 = QConvBNReLU(self.conv1, self.bn1, qi=True, qo=True, num_bits=bits)
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qconv2 = QConvBNReLU(self.conv2, self.bn2, qi=False, qo=True, num_bits=bits)
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=bits)

    def quantize_forward(self, x):
        """Run the quantized wrappers; returns the classifier output."""
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            x = stage(x)
        return self.qfc(x.view(-1, 5 * 5 * 40))

    def freeze(self):
        """Freeze quantization parameters, chaining each layer from the upstream ``qo``."""
        first, second = self.qconv1, self.qconv2
        first.freeze()
        self.qmaxpool2d_1.freeze(first.qo)
        # Max-pooling does not change min/max, so qconv2 reuses qconv1's output params.
        second.freeze(qi=first.qo)
        self.qmaxpool2d_2.freeze(second.qo)
        # Same reasoning: the classifier's input params are qconv2's output params.
        self.qfc.freeze(qi=second.qo)

    def quantize_inference(self, x):
        """Quantize ``x``, run the frozen integer pipeline, and dequantize the output."""
        qx = self.qconv1.qi.quantize_tensor(x)
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            qx = stage.quantize_inference(qx)
        qx = self.qfc.quantize_inference(qx.view(-1, 5 * 5 * 40))
        # Integer -> floating point.
        return self.qfc.qo.dequantize_tensor(qx)




# 定义 ResNet 模型
# 适用于Cifar10
class ResNet(nn.Module):
    """ResNet adapted to small (CIFAR-10 sized) inputs.

    Structure:
    - Stem: 3x3 conv with 16 channels, then BN and ReLU, with no initial max-pool.
    - Four stages built by ``MakeLayer`` with 16/32/64/128 base planes and strides 1/2/2/2.
    - Global average pooling, then a linear classifier.

    ``MakeLayer`` reads its input channel count from the shared counter
    ``GlobalVariables.SELF_INPLANES``; this constructor seeds it with 16
    before the stages are built.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose a class attribute ``expansion``.
        layers: Number of blocks in each of the four stages.
        num_classes: Number of output classes (10 for CIFAR-10).
        n_exp: Quantization setting, stored for the quantization path.
        mode: Quantization mode, stored for the quantization path.
    """

    def __init__(self, block, layers, num_classes=10, n_exp=4, mode=1):
        super().__init__()

        # CIFAR-10 images are small, so the network starts with fewer channels.
        self.inplanes = 16
        GlobalVariables.SELF_INPLANES = self.inplanes
        print('resnet init:' + str(GlobalVariables.SELF_INPLANES))

        # Kept for the quantization path, mirroring LeNet (not yet used here).
        self.n_exp = n_exp
        self.mode = mode

        # Stem
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

        # Residual stages: stage i holds layers[i] blocks.
        self.layer1 = MakeLayer(block, 16, layers[0])
        self.layer2 = MakeLayer(block, 32, layers[1], stride=2)
        self.layer3 = MakeLayer(block, 64, layers[2], stride=2)
        self.layer4 = MakeLayer(block, 128, layers[3], stride=2)

        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128 * block.expansion, num_classes)

        # Parameter initialisation: Kaiming for convs, identity affine for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Float forward pass; returns raw logits of shape (B, num_classes)."""
        # Stem. Unlike the ImageNet ResNet there is no max-pool: CIFAR images
        # are already small and pooling would shrink them too much.
        x = self.relu(self.bn1(self.conv1(x)))

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Classifier: avgpool gives (B, C, 1, 1), flattened to (B, C).
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    # TODO: the quantization path is not implemented for ResNet yet. The
    # stubs below keep the same interface as LeNet so generic training code
    # can call them; they do nothing, and the forward/inference stubs return None.

    def quantize(self, num_bits=8):
        """Not implemented for ResNet yet; does nothing."""

    def quantize_forward(self, x):
        """Not implemented for ResNet yet; returns None."""
        return None

    def freeze(self):
        """Not implemented for ResNet yet; does nothing."""

    def fakefreeze(self):
        """Not implemented for ResNet yet; does nothing."""

    def quantize_inference(self, x):
        """Not implemented for ResNet yet; returns None."""
        return None


# BasicBlock 类
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv+BN layers plus a shortcut.

    Output channels are ``planes * expansion`` (expansion = 1). When
    ``downsample`` is given, it is applied to the block input to form the
    shortcut (used when stride or channel count changes).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()

        # First 3x3 conv; carries the block's stride.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # Second 3x3 conv.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Float forward pass: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        out = self.relu(out)
        return out

    def quantize(self, num_bits=8):
        """Build quantized wrappers for both conv+BN pairs, the shortcut and the final ReLU."""
        # conv1 must be fused with its own BatchNorm (bn1).
        self.qconvbnrelu1 = QConvBNReLU(self.conv1, self.bn1, qi=False, qo=True, num_bits=num_bits)
        self.qconvbn1 = QConvBN(self.conv2, self.bn2, qi=False, qo=True, num_bits=num_bits)

        if self.downsample is not None:
            self.qconvbn2 = QConvBN(self.downsample[0], self.downsample[1],
                                    qi=False, qo=True, num_bits=num_bits)

        self.qrelu1 = QReLU()

    def quantize_forward(self, x):
        """Run the quantized wrappers with a plain float residual add."""
        identity = x
        out = self.qconvbnrelu1(x)
        out = self.qconvbn1(out)

        if self.downsample is not None:
            identity = self.qconvbn2(identity)

        # TODO: the residual add needs a proper quantized element-wise add.
        out = identity + out
        out = self.qrelu1(out)
        return out

    def freeze(self, qi=None):
        """Freeze the block's quantized sub-layers.

        Args:
            qi: Optional quantization parameters of the block input (for
                example the previous layer's ``qo``). ``qconvbnrelu1`` is
                built with ``qi=False``, so this is the way to supply them.
                When omitted, ``qconvbnrelu1.freeze()`` is called without
                arguments.
        """
        if qi is None:
            self.qconvbnrelu1.freeze()
        else:
            self.qconvbnrelu1.freeze(qi=qi)
        self.qconvbn1.freeze(qi=self.qconvbnrelu1.qo)

        if self.downsample is not None:
            # The shortcut conv consumes the block input, i.e. the same tensor
            # as qconvbnrelu1, so it shares that layer's input parameters.
            block_qi = qi if qi is not None else getattr(self.qconvbnrelu1, 'qi', None)
            self.qconvbn2.freeze(qi=block_qi)
            # TODO: the residual sum has no quantization params of its own;
            # the shortcut branch's output params are used as an approximation.
            self.qrelu1.freeze(self.qconvbn2.qo)
        else:
            self.qrelu1.freeze(self.qconvbn1.qo)

    def quantize_inference(self, x):
        """Integer-domain forward pass.

        No initial quantize_tensor / final dequantize_tensor is done here.
        The design assumption is that this block is neither the first nor the
        last layer, so its input already lives in the quantized domain.
        """
        identity = x
        out = self.qconvbnrelu1.quantize_inference(x)
        out = self.qconvbn1.quantize_inference(out)

        if self.downsample is not None:
            identity = self.qconvbn2.quantize_inference(identity)

        # TODO: the residual add needs a proper quantized element-wise add.
        out = identity + out
        out = self.qrelu1.quantize_inference(out)
        return out


    


# Bottleneck 类
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus a shortcut.

    Output channels are ``planes * expansion`` (expansion = 4). When
    ``downsample`` is given, it is applied to the block input to form the
    shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion

        # 1x1 conv: inplanes -> planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 conv carrying the block's stride
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 conv: planes -> planes * expansion
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)

        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Float forward pass: relu(main_path(x) + shortcut(x))."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # The residual addition happens here.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class MakeLayer(nn.Module):
    """One ResNet stage: ``blocks`` residual blocks stored in ``self.blockdict``.

    The stage's input channel count is read from the shared counter
    ``GlobalVariables.SELF_INPLANES``. After the first block is built, the
    counter is advanced to ``planes * block.expansion`` so that the following
    blocks, and the next stage, chain from it.

    Args:
        block: Residual block class with a class attribute ``expansion``.
        planes: Base number of channels for this stage.
        blocks: Number of residual blocks in the stage.
        stride: Stride of the first block (and of its projection shortcut).
    """

    def __init__(self, block, planes, blocks, stride=1):
        super().__init__()
        print('makelayer init:' + str(GlobalVariables.SELF_INPLANES))
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # spatial size (stride != 1) or the channel count.
        # NOTE: the same module is also handed to block1, so its parameters
        # are reachable under both this attribute and block1.downsample; kept
        # as-is for checkpoint (state_dict key) compatibility.
        self.downsample = None
        if stride != 1 or GlobalVariables.SELF_INPLANES != out_planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(GlobalVariables.SELF_INPLANES, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        # A ModuleDict (rather than nn.Sequential) so each block can be
        # reached by name and driven through its quantize* methods.
        self.blockdict = nn.ModuleDict()
        self.blockdict['block1'] = block(GlobalVariables.SELF_INPLANES, planes, stride, self.downsample)
        GlobalVariables.SELF_INPLANES = out_planes
        for i in range(2, blocks + 1):
            self.blockdict[f'block{i}'] = block(out_planes, planes)

    def forward(self, x):
        """Run the blocks in insertion order."""
        for layer in self.blockdict.values():
            x = layer(x)
        return x

    def quantize(self, num_bits=8):
        """Build each block's quantized wrappers with the requested bit width."""
        for layer in self.blockdict.values():
            # Each block defines its own quantization strategy.
            layer.quantize(num_bits=num_bits)

    def quantize_forward(self, x):
        """Run each block's ``quantize_forward`` in order."""
        for layer in self.blockdict.values():
            x = layer.quantize_forward(x)
        return x

    def freeze(self):
        """Freeze every block.

        Each block is frozen independently, without receiving the previous
        block's output params.
        TODO: consider passing the previous block's output params; needs checking.
        """
        for layer in self.blockdict.values():
            layer.freeze()

    def quantize_inference(self, x):
        """Integer-domain forward pass through all blocks.

        No initial quantize_tensor / final dequantize_tensor is done here.
        The design assumption is that a stage is neither the first nor the
        last layer of the network, so its input is already quantized.
        """
        for layer in self.blockdict.values():
            x = layer.quantize_inference(x)
        return x




# 使用 ResNet18 模型
def resnet18(**kwargs):
    """Build a ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage.

    Keyword arguments (e.g. ``num_classes``, ``n_exp``, ``mode``) go to ``ResNet``.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


# 使用 ResNet50 模型
def resnet50(**kwargs):
    """Build a ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage.

    Keyword arguments (e.g. ``num_classes``, ``n_exp``, ``mode``) go to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


# 使用 ResNet152 模型
def resnet152(**kwargs):
    """Build a ResNet-152: Bottleneck with [3, 8, 36, 3] blocks per stage.

    Keyword arguments (e.g. ``num_classes``, ``n_exp``, ``mode``) go to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)