# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *

# class VGG_19(nn.Module):
#     def __init__(self, img_size=32, input_channel=3, num_class=10):
#         super().__init__()
#         self.conv_param_layer_name = (
#             'conv1_1', 'relu1_1', 'conv1_2', 'bn1_1', 'relu1_2', 'pool1',
#             'conv2_1', 'bn2_1', 'relu2_1', 'conv2_2', 'bn2_2', 'relu2_2', 'pool2',
#             'conv3_1', 'bn3_1', 'relu3_1', 'conv3_2', 'bn3_2', 'relu3_2', 'conv3_3', 'bn3_3', 'relu3_3', 'conv3_4',
#             'bn3_4', 'relu3_4', 'pool3',
#             'conv4_1', 'bn4_1', 'relu4_1', 'conv4_2', 'bn4_2', 'relu4_2', 'conv4_3', 'bn4_3', 'relu4_3', 'conv4_4',
#             'bn4_4', 'relu4_4', 'pool4',
#             'conv5_1', 'bn5_1', 'relu5_1', 'conv5_2', 'bn5_2', 'relu5_2', 'conv5_3', 'bn5_3', 'relu5_3', 'conv5_4',
#             'bn5_4', 'relu5_4', 'pool5'
#         )

#         self.fc_param_layer_name = (
#             'fc1','relu1','drop1','fc2','relu2','drop2','fc3'
#         )

#         self.conv_layers = nn.ModuleDict({
#             # block1
#         'conv1_1': nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'relu1_1': nn.ReLU(),
#         'conv1_2': nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn1_1': nn.BatchNorm2d(num_features=64),
#         'relu1_2': nn.ReLU(),
#         'pool1': nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),

#         # block2
#         'conv2_1': nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn2_1': nn.BatchNorm2d(num_features=128),
#         'relu2_1': nn.ReLU(),
#         'conv2_2': nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn2_2': nn.BatchNorm2d(num_features=128),
#         'relu2_2': nn.ReLU(),
#         'pool2': nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),

#         # block3
#         'conv3_1': nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn3_1': nn.BatchNorm2d(num_features=256),
#         'relu3_1':  nn.ReLU(),

#         'conv3_2': nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn3_2':nn.BatchNorm2d(num_features=256),
#         'relu3_2': nn.ReLU(),

#         'conv3_3': nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn3_3': nn.BatchNorm2d(num_features=256),
#         'relu3_3': nn.ReLU(),

#         'conv3_4': nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn3_4': nn.BatchNorm2d(num_features=256),
#         'relu3_4': nn.ReLU(),
#         'pool3': nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),

#         # block4
#         'conv4_1': nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn4_1': nn.BatchNorm2d(num_features=512),
#         'relu4_1': nn.ReLU(),

#         'conv4_2': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn4_2': nn.BatchNorm2d(num_features=512),
#         'relu4_2': nn.ReLU(),

#         'conv4_3': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn4_3': nn.BatchNorm2d(num_features=512),
#         'relu4_3': nn.ReLU(),

#         'conv4_4': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn4_4': nn.BatchNorm2d(num_features=512),
#         'relu4_4': nn.ReLU(),

#         'pool4': nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),

#         # block5
#         'conv5_1': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn5_1': nn.BatchNorm2d(num_features=512),
#         'relu5_1': nn.ReLU(),

#         'conv5_2': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn5_2': nn.BatchNorm2d(num_features=512),
#         'relu5_2': nn.ReLU(),

#         'conv5_3': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn5_3': nn.BatchNorm2d(num_features=512),
#         'relu5_3': nn.ReLU(),

#         'conv5_4': nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
#         'bn5_4': nn.BatchNorm2d(num_features=512),
#         'relu5_4': nn.ReLU(),

#         'pool5': nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

#         })
#         self.fc_layers = nn.ModuleDict({
#             # classifier
#             'fc1': nn.Linear(512 * (int)(img_size * img_size / 32 / 32), 4096),
#             'relu1': nn.ReLU(),
#             'drop1': nn.Dropout(0.5),
#             'fc2': nn.Linear(4096, 4096),
#             'relu2': nn.ReLU(),
#             'drop2': nn.Dropout(0.5),
#             'fc3': nn.Linear(4096, num_class)
#         })


#     def forward(self,x):

#         for _,layer in self.conv_layers.items():
#             x = layer(x)

#         output = x.view(x.size()[0], -1)
#         for _,layer in self.fc_layers.items():
#             output = layer(output)

#         out = F.softmax(output,dim = 1)         # 这里不softmax也行 影响不大
#         return out

#     def quantize(self, num_bits=8):

#         self.quantize_conv_layers=nn.ModuleDict({
#             # qi=true: 前一层输出的结果是没有量化过的，需要量化。 maxpool和relu都不会影响INT和minmax，所以在这俩之后的层的pi是false
#             #若前一层是conv，数据minmax被改变，则需要qi=true来量化
#             'qconv1_1': QConv2d(self.conv_layers['conv1_1'], qi=True, qo=True, num_bits=num_bits),
#             'qrelu1_1': QReLU(),
#             'qconvbnrelu1_1': QConvBNReLU(self.conv_layers['conv1_2'],self.conv_layers['bn1_1'],qi=False,qo=True,num_bits=num_bits),
#             'qpool1': QMaxPooling2d(kernel_size=2,stride=2,padding=0),

#             # block2
#             'qconvbnrelu2_1': QConvBNReLU(self.conv_layers['conv2_1'], self.conv_layers['bn2_1'], qi=False, qo=True, num_bits=num_bits),
#             'qconvbnrelu2_2': QConvBNReLU(self.conv_layers['conv2_2'], self.conv_layers['bn2_2'], qi=False, qo=True, num_bits=num_bits),
#             'qpool2': QMaxPooling2d(kernel_size=2,stride=2,padding=0),

#             # block3
#             'qconvbnrelu3_1': QConvBNReLU(self.conv_layers['conv3_1'], self.conv_layers['bn3_1'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu3_2': QConvBNReLU(self.conv_layers['conv3_2'], self.conv_layers['bn3_2'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu3_3': QConvBNReLU(self.conv_layers['conv3_3'], self.conv_layers['bn3_3'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu3_4': QConvBNReLU(self.conv_layers['conv3_4'], self.conv_layers['bn3_4'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qpool3':  QMaxPooling2d(kernel_size=2,stride=2,padding=0),

#             # block4
#             'qconvbnrelu4_1': QConvBNReLU(self.conv_layers['conv4_1'], self.conv_layers['bn4_1'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu4_2': QConvBNReLU(self.conv_layers['conv4_2'], self.conv_layers['bn4_2'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu4_3': QConvBNReLU(self.conv_layers['conv4_3'], self.conv_layers['bn4_3'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu4_4': QConvBNReLU(self.conv_layers['conv4_4'], self.conv_layers['bn4_4'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qpool4': QMaxPooling2d(kernel_size=2,stride=2,padding=0),

#             # block5
#             'qconvbnrelu5_1': QConvBNReLU(self.conv_layers['conv5_1'], self.conv_layers['bn5_1'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu5_2': QConvBNReLU(self.conv_layers['conv5_2'], self.conv_layers['bn5_2'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu5_3': QConvBNReLU(self.conv_layers['conv5_3'], self.conv_layers['bn5_3'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qconvbnrelu5_4': QConvBNReLU(self.conv_layers['conv5_4'], self.conv_layers['bn5_4'], qi=False, qo=True,
#                                           num_bits=num_bits),
#             'qpool5': QMaxPooling2d(kernel_size=2,stride=2,padding=0)
#         })

#         self.quantize_fc_layers = nn.ModuleDict({
#             'qfc1': QLinear(self.fc_layers['fc1'],qi=False,qo=True,num_bits=num_bits),
#             'qrelu1': QReLU(),
#             'qdrop1': nn.Dropout(0.5),
#             'qfc2': QLinear(self.fc_layers['fc2'],qi=False,qo=True,num_bits=num_bits),
#             'qrelu2': QReLU(),
#             'qdrop2': nn.Dropout(0.5),
#             'qfc3':  QLinear(self.fc_layers['fc3'],qi=False,qo=True,num_bits=num_bits)
#         })

#     def quantize_forward(self, x):

#         for _, layer in self.quantize_conv_layers.items():
#             x = layer(x)

#         output = x.view(x.size()[0],-1)
#         for s, layer in self.quantize_fc_layers.items():
#             # if (s=='qrelu1') == True or (s=='qrelu2')==True:
#             #     output = nn.Dropout(0.5)
#             # else:
#             output = layer(output)

#         out = F.softmax(output, dim=1)  # 这里不softmax也行 影响不大 算loss用
#         return out


#     def freeze(self):

#         self.quantize_conv_layers['qconv1_1'].freeze()
#         self.quantize_conv_layers['qrelu1_1'].freeze(self.quantize_conv_layers['qconv1_1'].qo)
#         self.quantize_conv_layers['qconvbnrelu1_1'].freeze(qi=self.quantize_conv_layers['qconv1_1'].qo)
#         #self.quantize_conv_layers['qconvbnrelu1_1'].freeze(qi=self.quantize_conv_layers['qrelu1_1'].qo)
#         self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconvbnrelu1_1'].qo)


#         self.quantize_conv_layers['qconvbnrelu2_1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu1_1'].qo)
#         self.quantize_conv_layers['qconvbnrelu2_2'].freeze(qi=self.quantize_conv_layers['qconvbnrelu2_1'].qo)
#         self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconvbnrelu2_2'].qo)

#         self.quantize_conv_layers['qconvbnrelu3_1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu2_2'].qo)
#         #self.quantize_conv_layers['qconvbnrelu3_1'].freeze(qi=self.quantize_conv_layers['qpool2'].qo)
#         self.quantize_conv_layers['qconvbnrelu3_2'].freeze(qi=self.quantize_conv_layers['qconvbnrelu3_1'].qo)
#         self.quantize_conv_layers['qconvbnrelu3_3'].freeze(qi=self.quantize_conv_layers['qconvbnrelu3_2'].qo)
#         self.quantize_conv_layers['qconvbnrelu3_4'].freeze(qi=self.quantize_conv_layers['qconvbnrelu3_3'].qo)
#         self.quantize_conv_layers['qpool3'].freeze(self.quantize_conv_layers['qconvbnrelu3_4'].qo)

#         self.quantize_conv_layers['qconvbnrelu4_1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu3_4'].qo)
#         #self.quantize_conv_layers['qconvbnrelu4_1'].freeze(qi=self.quantize_conv_layers['qpool3'].qo)
#         self.quantize_conv_layers['qconvbnrelu4_2'].freeze(qi=self.quantize_conv_layers['qconvbnrelu4_1'].qo)
#         self.quantize_conv_layers['qconvbnrelu4_3'].freeze(qi=self.quantize_conv_layers['qconvbnrelu4_2'].qo)
#         self.quantize_conv_layers['qconvbnrelu4_4'].freeze(qi=self.quantize_conv_layers['qconvbnrelu4_3'].qo)
#         self.quantize_conv_layers['qpool4'].freeze(self.quantize_conv_layers['qconvbnrelu4_4'].qo)

#         self.quantize_conv_layers['qconvbnrelu5_1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu4_4'].qo)
#         #self.quantize_conv_layers['qconvbnrelu5_1'].freeze(qi=self.quantize_conv_layers['qpool4'].qo)
#         self.quantize_conv_layers['qconvbnrelu5_2'].freeze(qi=self.quantize_conv_layers['qconvbnrelu5_1'].qo)
#         self.quantize_conv_layers['qconvbnrelu5_3'].freeze(qi=self.quantize_conv_layers['qconvbnrelu5_2'].qo)
#         self.quantize_conv_layers['qconvbnrelu5_4'].freeze(qi=self.quantize_conv_layers['qconvbnrelu5_3'].qo)
#         self.quantize_conv_layers['qpool5'].freeze(self.quantize_conv_layers['qconvbnrelu5_4'].qo)

#         self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu5_4'].qo)
#         #self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qpool5'].qo)
#         self.quantize_fc_layers['qrelu1'].freeze(self.quantize_fc_layers['qfc1'].qo)
#         self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qfc1'].qo)
#         #self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qrelu1'].qo)
#         self.quantize_fc_layers['qrelu2'].freeze(self.quantize_fc_layers['qfc2'].qo)
#         self.quantize_fc_layers['qfc3'].freeze(qi=self.quantize_fc_layers['qfc2'].qo)
#         #self.quantize_fc_layers['qfc3'].freeze(qi=self.quantize_fc_layers['qrelu2'].qo)



#     def quantize_inference(self, x):
#         x = self.quantize_conv_layers['qconv1_1'].qi.quantize_tensor(x)

#         for s, layer in self.quantize_conv_layers.items():
#             x=layer.quantize_inference(x)

#         output = x.view(x.size()[0], -1)

#         for s, layer in self.quantize_fc_layers.items():

#             # elif (s == 'qrelu1') == True or (s == 'qrelu2') == True:
#             #     output = nn.Dropout(0.5)
#             # if (s == 'qdrop1')==True or (s=='qdrop2')==True:
#             #     output = F.dropout(output,0.45)
#             # else:
#             if ((s == 'qdrop1') == False ) and ((s == 'qdrop2') == False):
#                 output = layer.quantize_inference(output)
#             else:
#                 output = output


#         output = self.quantize_fc_layers['qfc3'].qo.dequantize_tensor(output)

#         out = F.softmax(output, dim=1)  # 这里应该用 Qsoftmax可能好些 之后改改
#         return out


class LeNet(nn.Module):
    """LeNet-5 style classifier (conv+BN+ReLU blocks) with quantization hooks.

    Float path: ``forward``. Quantization workflow (wrappers come from
    ``module``): ``quantize`` wraps the float layers, ``quantize_forward``
    runs the fake-quantized model, ``freeze`` / ``fakefreeze`` fix the
    quantization parameters, and ``quantize_inference`` runs the frozen
    quantized pipeline.

    FLOP-counting conventions used in the per-layer comments:
      Conv: with bias    (2 * C_in * K_h * K_w)     * H_out * W_out * C_out
            without bias (2 * C_in * K_h * K_w - 1) * H_out * W_out * C_out
      FC:   with bias    (2 * I) * O
            without bias (2 * I - 1) * O

    Args:
        img_size: accepted but currently unused; the layer sizes are fixed for
            32x32 inputs (16 * 5 * 5 flattened features).
        input_channel: accepted but currently unused; ``conv1`` always takes
            3 input channels.
        num_class: number of output classes.
        n_exp: forwarded to every quantized layer (meaning defined in ``module``).
        mode: forwarded to every quantized layer; when ``mode == 1`` the final
            output is dequantized in ``quantize_inference`` (per the original
            notes, other modes are power-of-two (PoT) and need no range mapping).
    """

    def __init__(self, img_size=32, input_channel=3, num_class=10, n_exp=4, mode=1):
        super().__init__()
        # NOTE(review): img_size and input_channel are ignored; the sizes below
        # assume 3x32x32 inputs.
        self.conv_layers = nn.ModuleDict({
            # block1: 3x32x32 -> conv(5x5, no padding) 6x28x28 -> pool 6x14x14
            'conv1': nn.Conv2d(3, 6, 5),  # (2*3*5*5) * 28*28*6 = 705600 FLOPs; bias accounts for 28*28*6 = 4704
            'bn1': nn.BatchNorm2d(6),
            'reluc1': nn.ReLU(),
            'pool1': nn.MaxPool2d(2, 2),

            # block2: 6x14x14 -> conv(5x5, no padding) 16x10x10 -> pool 16x5x5
            'conv2': nn.Conv2d(6, 16, 5),  # (2*6*5*5) * 10*10*16 = 480000 FLOPs; bias accounts for 10*10*16 = 1600
            'bn2': nn.BatchNorm2d(16),
            'reluc2': nn.ReLU(),
            'pool2': nn.MaxPool2d(2, 2),
        })

        self.fc_layers = nn.ModuleDict({
            # classifier
            'fc1': nn.Linear(16 * 5 * 5, 120),  # (2*400)*120 = 96000 FLOPs; bias accounts for 120
            'reluf1': nn.ReLU(),
            'fc2': nn.Linear(120, 84),  # (2*120)*84 = 20160 FLOPs; bias accounts for 84
            'reluf2': nn.ReLU(),
            'fc3': nn.Linear(84, num_class),
        })
        self.mode = mode
        self.n_exp = n_exp

    def forward(self, x):
        """Float forward pass; returns softmax probabilities over dim 1."""
        for layer in self.conv_layers.values():
            x = layer(x)

        output = x.view(-1, 16 * 5 * 5)
        for layer in self.fc_layers.values():
            output = layer(output)

        # Softmax is optional here; the original note says it makes little difference.
        out = F.softmax(output, dim=1)
        return out

    def quantize(self, num_bits=8):
        """Wrap the float layers in quantized counterparts.

        Must be called before quantize_forward / freeze / fakefreeze /
        quantize_inference. Each conv+bn(+relu) triple is fused into one
        QConvBNReLU, so the conv-block ReLUs are not wrapped separately.

        Args:
            num_bits: bit width passed to every quantized conv/linear layer.
        """
        self.quantize_conv_layers = nn.ModuleDict({
            # qi=True: the layer keeps its own input-quantization parameters
            # (required for the first layer, whose input is raw float data).
            # Per the original design, max-pool and ReLU are treated as not
            # changing integer values or the min/max range.
            'qconvbnrelu1': QConvBNReLU(self.conv_layers['conv1'], self.conv_layers['bn1'],
                                        qi=True, qo=True, num_bits=num_bits,
                                        n_exp=self.n_exp, mode=self.mode),
            'qpool1': QMaxPooling2d(kernel_size=2, stride=2, padding=0,
                                    n_exp=self.n_exp, mode=self.mode),
            # Must be a distinct key: a repeated key in a dict literal keeps only
            # the last value, which would drop the conv1 block entirely.
            # NOTE(review): qi=True makes this block track its own input range,
            # while in quantize_inference its input arrives quantized with
            # qconvbnrelu1.qo (via qpool1). Net/NetBN instead use qi=False plus
            # freeze(qi=<previous qo>); confirm which behaviour is intended.
            'qconvbnrelu2': QConvBNReLU(self.conv_layers['conv2'], self.conv_layers['bn2'],
                                        qi=True, qo=True, num_bits=num_bits,
                                        n_exp=self.n_exp, mode=self.mode),
            'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0,
                                    n_exp=self.n_exp, mode=self.mode),
        })

        self.quantize_fc_layers = nn.ModuleDict({
            'qfc1': QLinear(self.fc_layers['fc1'], qi=False, qo=True, num_bits=num_bits,
                            n_exp=self.n_exp, mode=self.mode),
            'qreluf1': QReLU(n_exp=self.n_exp, mode=self.mode),
            'qfc2': QLinear(self.fc_layers['fc2'], qi=False, qo=True, num_bits=num_bits,
                            n_exp=self.n_exp, mode=self.mode),
            'qreluf2': QReLU(n_exp=self.n_exp, mode=self.mode),
            'qfc3': QLinear(self.fc_layers['fc3'], qi=False, qo=True, num_bits=num_bits,
                            n_exp=self.n_exp, mode=self.mode),
        })

    def quantize_forward(self, x):
        """Forward pass through the quantized wrappers (fake-quant training path).

        Returns softmax probabilities over dim 1.
        """
        for layer in self.quantize_conv_layers.values():
            x = layer(x)

        output = x.view(-1, 16 * 5 * 5)
        for layer in self.quantize_fc_layers.values():
            output = layer(output)

        # Softmax is optional here; the original note says it is used for the loss.
        out = F.softmax(output, dim=1)
        return out

    def freeze(self):
        """Fix the quantization parameters for quantized inference.

        Layers created with qi=True freeze with their own qi. Every other layer
        receives the qo of the layer that produced its input.
        """
        convs = self.quantize_conv_layers
        fcs = self.quantize_fc_layers

        convs['qconvbnrelu1'].freeze()
        convs['qpool1'].freeze(convs['qconvbnrelu1'].qo)

        convs['qconvbnrelu2'].freeze()
        convs['qpool2'].freeze(convs['qconvbnrelu2'].qo)

        # qpool2 is treated as range-preserving, so fc1's input params are conv2's qo.
        fcs['qfc1'].freeze(qi=convs['qconvbnrelu2'].qo)
        fcs['qreluf1'].freeze(fcs['qfc1'].qo)
        fcs['qfc2'].freeze(qi=fcs['qfc1'].qo)
        fcs['qreluf2'].freeze(fcs['qfc2'].qo)
        fcs['qfc3'].freeze(qi=fcs['qfc2'].qo)

    def fakefreeze(self):
        """Same wiring as freeze(), but calls each layer's fakefreeze().

        The exact semantics of fakefreeze are defined in ``module``.
        """
        convs = self.quantize_conv_layers
        fcs = self.quantize_fc_layers

        convs['qconvbnrelu1'].fakefreeze()
        convs['qpool1'].fakefreeze(convs['qconvbnrelu1'].qo)

        convs['qconvbnrelu2'].fakefreeze()
        convs['qpool2'].fakefreeze(convs['qconvbnrelu2'].qo)

        fcs['qfc1'].fakefreeze(qi=convs['qconvbnrelu2'].qo)
        fcs['qreluf1'].fakefreeze(fcs['qfc1'].qo)
        fcs['qfc2'].fakefreeze(qi=fcs['qfc1'].qo)
        fcs['qreluf2'].fakefreeze(fcs['qfc2'].qo)
        fcs['qfc3'].fakefreeze(qi=fcs['qfc2'].qo)

    def quantize_inference(self, x):
        """Run the frozen quantized pipeline on float input ``x``.

        Must be called after freeze(). Returns softmax probabilities over dim 1.
        """
        x = self.quantize_conv_layers['qconvbnrelu1'].qi.quantize_tensor(x, self.mode)

        for layer in self.quantize_conv_layers.values():
            x = layer.quantize_inference(x)

        output = x.view(-1, 16 * 5 * 5)

        for layer in self.quantize_fc_layers.values():
            output = layer.quantize_inference(output)

        # Only mode 1 needs a range mapping back to values similar to the
        # original data; per the original notes, PoT quantization does not.
        if self.mode == 1:
            output = self.quantize_fc_layers['qfc3'].qo.dequantize_tensor(output, self.mode)

        # TODO: a quantized softmax might be a better fit here.
        out = F.softmax(output, dim=1)
        return out


class Net(nn.Module):
    """Small two-conv CNN (grouped second conv) with quantization hooks.

    The classifier expects 40 x 5 x 5 feature maps after the two
    conv(3x3) + max-pool(2) stages, e.g. 28x28 inputs.

    Quantization workflow: quantize() -> quantize_forward() (fine-tune the
    fake-quantized model) -> freeze() -> quantize_inference().

    Args:
        num_channels: number of input image channels.
    """

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        # forward(), quantize() and quantize_inference() all rely on these
        # layers; without them every method raises AttributeError.
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
        self.fc = nn.Linear(5 * 5 * 40, 10)

    def forward(self, x):
        """Float forward pass returning raw logits."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5 * 5 * 40)  # flatten to (batch, 1000)
        x = self.fc(x)
        return x

    def quantize(self, num_bits=8):
        """Build the quantized layers used for fake-quant fine-tuning.

        Determining the quantization parameters is a preliminary step; the
        later training of the quantized model is a fine-tune.

        Args:
            num_bits: bit width passed to the quantized conv/linear layers.
        """
        # Only the first layer has qi=True. During forward, every layer other
        # than pool/relu re-quantizes its output with its qo, so the activation
        # passed between layers is already quantized-then-dequantized; only the
        # weights need per-layer quantization.
        self.qconv1 = QConv2d(self.conv1, qi=True, qo=True, num_bits=num_bits)
        self.qrelu1 = QReLU()
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        # qi=False: this layer keeps no input scale/zero_point of its own;
        # freeze() supplies qconv1.qo instead.
        self.qconv2 = QConv2d(self.conv2, qi=False, qo=True, num_bits=num_bits)
        self.qrelu2 = QReLU()
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=num_bits)

    def quantize_forward(self, x):
        """Forward pass used while training (fine-tuning) the quantized model."""
        x = self.qconv1(x)
        x = self.qrelu1(x)
        x = self.qmaxpool2d_1(x)
        x = self.qconv2(x)
        x = self.qrelu2(x)
        x = self.qmaxpool2d_2(x)
        x = x.view(-1, 5 * 5 * 40)
        x = self.qfc(x)
        return x

    def freeze(self):
        """Fix the quantization parameters after fine-tuning.

        This decides which quantization parameters each layer uses at
        inference. Only convs and the linear layer produce a new qo; relu
        and max-pool reuse the qo of the conv in front of them.
        """
        self.qconv1.freeze()
        # The preceding conv's qo serves as qi. Per the original notes, qo is
        # updated during training in each quantized layer's forward from global
        # min/max statistics.
        self.qrelu1.freeze(self.qconv1.qo)
        self.qmaxpool2d_1.freeze(self.qconv1.qo)
        self.qconv2.freeze(qi=self.qconv1.qo)
        # relu and max-pool are treated as not changing the global min/max,
        # so they inherit qconv2's qo.
        self.qrelu2.freeze(self.qconv2.qo)
        self.qmaxpool2d_2.freeze(self.qconv2.qo)
        self.qfc.freeze(qi=self.qconv2.qo)

    def quantize_inference(self, x):
        """Frozen-model inference: float in, quantized pipeline, float out."""
        qx = self.qconv1.qi.quantize_tensor(x)
        qx = self.qconv1.quantize_inference(qx)
        qx = self.qrelu1.quantize_inference(qx)
        qx = self.qmaxpool2d_1.quantize_inference(qx)
        qx = self.qconv2.quantize_inference(qx)
        qx = self.qrelu2.quantize_inference(qx)
        qx = self.qmaxpool2d_2.quantize_inference(qx)
        qx = qx.view(-1, 5 * 5 * 40)
        qx = self.qfc.quantize_inference(qx)
        out = self.qfc.qo.dequantize_tensor(qx)
        return out


class NetBN(nn.Module):
    """Two conv+BN+ReLU+max-pool stages followed by a linear classifier.

    The classifier expects 40 x 5 x 5 feature maps (e.g. 28x28 inputs).
    Quantization workflow: quantize() -> quantize_forward() -> freeze() ->
    quantize_inference().

    Args:
        num_channels: number of input image channels.
    """

    def __init__(self, num_channels=1):
        super(NetBN, self).__init__()
        width = 40
        # Submodules are created in the order conv1, bn1, conv2, bn2, fc, so
        # parameter registration and random initialisation follow that order.
        self.conv1 = nn.Conv2d(num_channels, width, 3, 1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, 1)
        self.bn2 = nn.BatchNorm2d(width)
        self.fc = nn.Linear(5 * 5 * width, 10)

    def forward(self, x):
        """Float forward pass returning raw logits."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.max_pool2d(F.relu(norm(conv(x))), 2, 2)
        flat = x.view(-1, 5 * 5 * 40)
        return self.fc(flat)

    def quantize(self, num_bits=8):
        """Create fused conv+BN+ReLU, pooling and linear quantized wrappers.

        Only the first block quantizes its own input (qi=True); later layers
        get their input parameters from freeze().
        """
        shared = dict(qo=True, num_bits=num_bits)
        pool_cfg = dict(kernel_size=2, stride=2, padding=0)
        self.qconv1 = QConvBNReLU(self.conv1, self.bn1, qi=True, **shared)
        self.qmaxpool2d_1 = QMaxPooling2d(**pool_cfg)
        self.qconv2 = QConvBNReLU(self.conv2, self.bn2, qi=False, **shared)
        self.qmaxpool2d_2 = QMaxPooling2d(**pool_cfg)
        self.qfc = QLinear(self.fc, qi=False, **shared)

    def quantize_forward(self, x):
        """Forward pass through the quantized wrappers (fake-quant training)."""
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            x = stage(x)
        return self.qfc(x.view(-1, 5*5*40))

    def freeze(self):
        """Fix quantization parameters, chaining each conv's qo forward."""
        self.qconv1.freeze()
        first_qo = self.qconv1.qo
        self.qmaxpool2d_1.freeze(first_qo)
        # max-pool does not change min/max, so conv2's input params are conv1's qo
        self.qconv2.freeze(qi=first_qo)
        second_qo = self.qconv2.qo
        self.qmaxpool2d_2.freeze(second_qo)
        # likewise, fc's input params are conv2's qo
        self.qfc.freeze(qi=second_qo)

    def quantize_inference(self, x):
        """Frozen-model inference: float in, quantized pipeline, float out."""
        qx = self.qconv1.qi.quantize_tensor(x)
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            qx = stage.quantize_inference(qx)
        qx = self.qfc.quantize_inference(qx.view(-1, 5*5*40))
        return self.qfc.qo.dequantize_tensor(qx)   # INT -> FP
