import torch
import torch.nn as nn
import torch.nn.functional as F

from cfg import *
from module import *
import module

def make_layers(model,cfg_table):
    """Instantiate the float layers described by ``cfg_table`` on ``model``.

    Every entry is a sequence whose first element is the layer kind. Modules
    are registered with ``model.add_module`` under a name suffixed by the
    entry's position in the table (``conv3``, ``bn3``, ``pool4``, ...); this is
    the naming that ``model_forward`` and ``model_quantize`` look up later.

    Entry layouts consumed here:
        ('Inc', inc_idx)                        -> make_inc_layers(model, inc_idx)
        ('C', flags, qi, in_ch, out_ch, k, s, p, bias)
                                                -> Conv2d; + BatchNorm2d if 'B' in flags;
                                                   + ReLU only if 'B' and 'R' are both in flags
        ('R', ...)                              -> ReLU(inplace=True)
        ('MP', k, s, p)                         -> MaxPool2d
        ('AAP', out_size)                       -> AdaptiveAvgPool2d
        ('FC', in_f, out_f, bias)               -> Linear
        ('D', p)                                -> Dropout
    Any other kind (e.g. 'FT', which is handled only at forward time)
    registers nothing.

    Args:
        model: the ``nn.Module`` to populate (mutated in place).
        cfg_table: sequence of layer entries as described above.
    """
    # kind -> (module-name format, builder taking the whole entry)
    simple_builders = {
        'R': ('relu%d', lambda e: nn.ReLU(True)),
        'MP': ('pool%d', lambda e: nn.MaxPool2d(kernel_size=e[1], stride=e[2], padding=e[3])),
        'AAP': ('aap%d', lambda e: nn.AdaptiveAvgPool2d(e[1])),
        'FC': ('fc%d', lambda e: nn.Linear(e[1], e[2], bias=e[3])),
        'D': ('drop%d', lambda e: nn.Dropout(e[1])),
    }
    for idx, entry in enumerate(cfg_table):
        kind = entry[0]
        if kind == 'Inc':
            make_inc_layers(model, entry[1])
        elif kind == 'C':
            flags = entry[1]
            conv = nn.Conv2d(entry[3], entry[4], kernel_size=entry[5],
                             stride=entry[6], padding=entry[7], bias=entry[8])
            model.add_module('conv%d' % idx, conv)
            # BN, and the ReLU after it, exist only when 'B' is present;
            # an 'R' flag without 'B' is ignored.
            if 'B' in flags:
                model.add_module('bn%d' % idx, nn.BatchNorm2d(entry[4]))
                if 'R' in flags:
                    model.add_module('relu%d' % idx, nn.ReLU(True))
        elif kind in simple_builders:
            name_fmt, build = simple_builders[kind]
            model.add_module(name_fmt % idx, build(entry))
        

def model_forward(model,cfg_table,x):
    """Run the float network described by ``cfg_table`` on ``x``.

    Layers are fetched by the names ``make_layers`` registered (``conv<i>``,
    ``bn<i>``, ``relu<i>``, ``pool<i>``, ``aap<i>``, ``fc<i>``, ``drop<i>``, where
    ``i`` is the entry's index). 'Inc' entries go through ``inc_forward``.
    'FT' flattens every dimension after the first. Unknown kinds are skipped.

    Args:
        model: module populated by ``make_layers``.
        cfg_table: the same table passed to ``make_layers``.
        x: network input.

    Returns:
        The output of the last layer applied.
    """
    single_layer_fmt = {
        'R': 'relu%d',
        'MP': 'pool%d',
        'AAP': 'aap%d',
        'FC': 'fc%d',
        'D': 'drop%d',
    }
    for idx, entry in enumerate(cfg_table):
        kind = entry[0]
        if kind == 'Inc':
            x = inc_forward(model, entry[1], x)
        elif kind == 'C':
            # Mirror make_layers: BN only with 'B', ReLU only with both 'B' and 'R'.
            chain = ['conv%d' % idx]
            if 'B' in entry[1]:
                chain.append('bn%d' % idx)
                if 'R' in entry[1]:
                    chain.append('relu%d' % idx)
            for layer_name in chain:
                x = getattr(model, layer_name)(x)
        elif kind == 'FT':
            x = torch.flatten(x, start_dim=1)
        elif kind in single_layer_fmt:
            x = getattr(model, single_layer_fmt[kind] % idx)(x)

    return x

def model_quantize(model,cfg_table,quant_type,num_bits,e_bits):
    """Register a quantized wrapper ``q_<name>`` for each quantizable layer.

    For every entry of ``cfg_table`` the wrapper is added to ``model`` next to
    the float layer created by ``make_layers``:
        'C'   -> QConvBNReLU (flags contain 'B' and 'R'), QConvBN ('B' only),
                 or QConv2d (no 'B'); ``cfg[2]`` is forwarded as ``qi``
        'R'   -> QReLU
        'MP'  -> QMaxPooling2d with the entry's kernel/stride/padding
        'AAP' -> QAdaptiveAvgPool2d with the entry's output size
        'FC'  -> QLinear wrapping ``fc<i>``
        'Inc' -> delegated to ``inc_quantize``
    'D', 'FT' and unknown kinds get no wrapper.

    Args:
        model: module populated by ``make_layers`` (mutated in place).
        cfg_table: the same table passed to ``make_layers``.
        quant_type, num_bits, e_bits: forwarded unchanged to every wrapper.
    """
    for idx, entry in enumerate(cfg_table):
        kind = entry[0]
        if kind == 'Inc':
            inc_quantize(model, entry[1], quant_type, num_bits, e_bits)
        elif kind == 'C':
            conv_name = 'conv%d' % idx
            conv = getattr(model, conv_name)
            flags, qi_flag = entry[1], entry[2]
            if 'B' not in flags:
                qlayer = QConv2d(quant_type, conv, qi=qi_flag, num_bits=num_bits, e_bits=e_bits)
            else:
                bn = getattr(model, 'bn%d' % idx)
                wrapper_cls = QConvBNReLU if 'R' in flags else QConvBN
                qlayer = wrapper_cls(quant_type, conv, bn, qi=qi_flag, num_bits=num_bits, e_bits=e_bits)
            model.add_module('q_' + conv_name, qlayer)
        elif kind == 'R':
            qlayer = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
            model.add_module('q_relu%d' % idx, qlayer)
        elif kind == 'MP':
            qlayer = QMaxPooling2d(quant_type, kernel_size=entry[1], stride=entry[2], padding=entry[3],
                                   num_bits=num_bits, e_bits=e_bits)
            model.add_module('q_pool%d' % idx, qlayer)
        elif kind == 'AAP':
            qlayer = QAdaptiveAvgPool2d(quant_type, output_size=entry[1], num_bits=num_bits, e_bits=e_bits)
            model.add_module('q_aap%d' % idx, qlayer)
        elif kind == 'FC':
            fc_name = 'fc%d' % idx
            fc = getattr(model, fc_name)
            qlayer = QLinear(quant_type, fc, num_bits=num_bits, e_bits=e_bits)
            model.add_module('q_' + fc_name, qlayer)

# Added support for func='fakefreeze'.
def model_utils(model,cfg_table,func,x=None):
    """Walk the quantized layers of ``model`` and apply one operation to each.

    ``func`` selects the operation:
      * ``'forward'``    -- call each quantized wrapper (``qlayer(x)``); dropout
                            layers (``drop<i>``) are applied as well.
      * ``'inference'``  -- for conv entries whose ``cfg[2]`` is truthy (the
                            network's entry layer), quantize the input with the
                            wrapper's ``qi``; run ``quantize_inference`` on every
                            wrapper; finally dequantize with the last ``qo``.
      * ``'freeze'``     -- call ``freeze(last_qo)`` on each wrapper, where
                            ``last_qo`` is the ``qo`` of the most recent conv /
                            AAP / FC / Inception layer (``None`` before the first).
      * ``'fakefreeze'`` -- call ``fakefreeze()`` on the conv and FC wrappers
                            (Inception convs are handled by ``inc_utils``).
    'FT' entries flatten from dim 1 in 'forward' and 'inference' only.

    Args:
        model: module populated by ``make_layers`` and ``model_quantize``.
        cfg_table: the same table those functions were given.
        func: one of the strings above; any other value performs no
            per-layer operation.
        x: input for 'forward' / 'inference'; unused otherwise.

    Returns:
        The network output for 'forward' / 'inference' (dequantized for
        'inference'). For other funcs the value is not meaningful: ``x``
        unchanged, or ``None`` once an Inception entry has been visited.

    Raises:
        RuntimeError: for 'inference' when no layer in ``cfg_table`` published an
            output quantizer to dequantize with (e.g. an empty table).
    """
    # Name of the quantized wrapper that model_quantize registers for each kind.
    qname_fmt = {
        'C': 'q_conv%d',
        'R': 'q_relu%d',
        'MP': 'q_pool%d',
        'AAP': 'q_aap%d',
        'FC': 'q_fc%d',
    }
    last_qo = None
    for idx, cfg in enumerate(cfg_table):
        kind = cfg[0]
        if kind == 'Inc':
            x, last_qo = inc_utils(model, cfg[1], func, x, last_qo)
            continue
        if kind == 'D':
            # Dropout has no quantized wrapper; it only runs in the 'forward' pass.
            if func == 'forward':
                x = getattr(model, 'drop%d' % idx)(x)
            continue
        if kind == 'FT':
            if func in ('inference', 'forward'):
                x = torch.flatten(x, start_dim=1)
            continue
        fmt = qname_fmt.get(kind)
        if fmt is None:
            continue  # unknown kinds are ignored

        qlayer = getattr(model, fmt % idx)
        if func == 'forward':
            x = qlayer(x)
        elif func == 'inference':
            # cfg[2] truthy marks the starting layer, whose float input must be
            # quantized before running integer inference.
            if kind == 'C' and cfg[2]:
                x = qlayer.qi.quantize_tensor(x)
            x = qlayer.quantize_inference(x)
        elif func == 'freeze':
            qlayer.freeze(last_qo)
        elif func == 'fakefreeze' and kind in ('C', 'FC'):
            qlayer.fakefreeze()
        # ReLU and max-pool wrappers do not publish a new qo here; layers after
        # them are frozen against the quantizer that preceded them.
        if kind in ('C', 'AAP', 'FC'):
            last_qo = qlayer.qo

    if func == 'inference':
        if last_qo is None:
            raise RuntimeError(
                "model_utils: no quantized layer with an output quantizer (qo) "
                "found in cfg_table; cannot dequantize the inference output")
        x = last_qo.dequantize_tensor(x)

    return x

def make_inc_layers(model,inc_idx):
    """Register the float layers of Inception block ``inc_idx`` on ``model``.

    The block has four branches. Branch ``b`` is described by
    ``inc_cfg_table[b]`` and its modules are named
    ``inc<inc_idx>_br<b>_<layer><j>``, where ``j`` is the entry's index in
    that branch. Conv entries store indices into ``inc_ch_table[inc_idx]``
    (``entry[1]`` -> input channels, ``entry[2]`` -> output channels).
    Entry kinds other than 'MP', 'R' and 'C' register nothing.
    """
    channels = inc_ch_table[inc_idx]
    for br in range(4):
        prefix = 'inc%d_br%d_' % (inc_idx, br)
        for j, entry in enumerate(inc_cfg_table[br]):
            kind = entry[0]
            if kind == 'MP':
                pool = nn.MaxPool2d(kernel_size=entry[1], stride=entry[2], padding=entry[3])
                model.add_module(prefix + 'pool%d' % j, pool)
            elif kind == 'R':
                model.add_module(prefix + 'relu%d' % j, nn.ReLU(True))
            elif kind == 'C':
                in_ch, out_ch = channels[entry[1]], channels[entry[2]]
                # Every Inception conv is bias-free and followed by BN and ReLU.
                conv = nn.Conv2d(in_ch, out_ch, kernel_size=entry[3], stride=entry[4],
                                 padding=entry[5], bias=False)
                model.add_module(prefix + 'conv%d' % j, conv)
                model.add_module(prefix + 'bn%d' % j, nn.BatchNorm2d(out_ch))
                model.add_module(prefix + 'relu%d' % j, nn.ReLU(True))

def inc_forward(model,inc_idx,x=None):
    """Run Inception block ``inc_idx`` in float mode.

    Each of the four branches receives the same input ``x``. Per branch
    entry: 'MP' applies the pool, 'R' the ReLU, and any other kind is treated
    as a conv entry (conv -> bn -> relu). Branch outputs are concatenated
    along dim 1 (presumably the channel axis of an NCHW tensor; confirm
    against callers).
    """
    branch_outs = []
    for br in range(4):
        prefix = 'inc%d_br%d_' % (inc_idx, br)
        y = x
        for j, entry in enumerate(inc_cfg_table[br]):
            kind = entry[0]
            if kind == 'MP':
                steps = ('pool',)
            elif kind == 'R':
                steps = ('relu',)
            else:
                steps = ('conv', 'bn', 'relu')
            for step in steps:
                y = getattr(model, '%s%s%d' % (prefix, step, j))(y)
        branch_outs.append(y)

    return torch.cat(branch_outs, 1)

def inc_quantize(model,inc_idx,quant_type,num_bits,e_bits):
    """Register quantized wrappers for Inception block ``inc_idx``.

    For each branch entry a ``q_``-prefixed wrapper is added:
    'MP' -> QMaxPooling2d, 'R' -> QReLU, and any other kind -> QConvBNReLU
    wrapping the branch's conv and bn (``qi=False``, ``qo=True``). A single
    QConcat over the four branches is registered as ``q_inc<inc_idx>_concat``.
    ``quant_type``, ``num_bits`` and ``e_bits`` are forwarded unchanged.
    """
    qkw = dict(num_bits=num_bits, e_bits=e_bits)
    for br in range(4):
        prefix = 'inc%d_br%d_' % (inc_idx, br)
        for j, entry in enumerate(inc_cfg_table[br]):
            kind = entry[0]
            if kind == 'MP':
                name = prefix + 'pool%d' % j
                qlayer = QMaxPooling2d(quant_type, kernel_size=entry[1], stride=entry[2],
                                       padding=entry[3], **qkw)
            elif kind == 'R':
                name = prefix + 'relu%d' % j
                qlayer = QReLU(quant_type, **qkw)
            else:
                name = prefix + 'conv%d' % j
                conv = getattr(model, name)
                bn = getattr(model, prefix + 'bn%d' % j)
                qlayer = QConvBNReLU(quant_type, conv, bn, qi=False, qo=True, **qkw)
            model.add_module('q_' + name, qlayer)

    concat = QConcat(quant_type, 4, qi_array=False, qo=True, **qkw)
    model.add_module('q_inc%d_concat' % inc_idx, concat)
    
def inc_utils(model,inc_idx,func,x=None,qo=None):
    """Apply ``func`` to the quantized layers of Inception block ``inc_idx``.

    Mirrors ``model_utils`` for one Inception block. Every one of the four
    branches starts from the shared input ``x`` and the incoming quantizer
    ``qo``. Within a branch only conv wrappers replace the running quantizer;
    pool and ReLU wrappers are frozen against whatever precedes them. The
    per-branch outputs (or, for 'freeze', the per-branch final quantizers)
    are passed to the block's QConcat wrapper. The concat wrapper is only
    touched for 'forward', 'inference' and 'freeze'. 'fakefreeze' reaches
    only the branch convs.

    Returns:
        ``(out, concat_qo)``: ``out`` is the concatenated result for 'forward'
        / 'inference' and ``None`` otherwise; ``concat_qo`` is the concat
        wrapper's ``qo``.
    """
    branch_outs = []
    branch_qos = []
    for br in range(4):
        qprefix = 'q_inc%d_br%d_' % (inc_idx, br)
        y = x
        branch_qo = qo
        for j, entry in enumerate(inc_cfg_table[br]):
            kind = entry[0]
            if kind == 'MP':
                qlayer = getattr(model, qprefix + 'pool%d' % j)
            elif kind == 'R':
                qlayer = getattr(model, qprefix + 'relu%d' % j)
            elif kind == 'C':
                qlayer = getattr(model, qprefix + 'conv%d' % j)
            else:
                continue

            if func == 'forward':
                y = qlayer(y)
            elif func == 'inference':
                y = qlayer.quantize_inference(y)
            elif func == 'freeze':
                qlayer.freeze(branch_qo)
            elif func == 'fakefreeze' and kind == 'C':
                qlayer.fakefreeze()
            if kind == 'C':
                branch_qo = qlayer.qo

        branch_outs.append(y)
        branch_qos.append(branch_qo)

    concat = getattr(model, 'q_inc%d_concat' % inc_idx)
    out = None
    if func == 'forward':
        out = concat(branch_outs)
    elif func == 'inference':
        out = concat.quantize_inference(branch_outs)
    elif func == 'freeze':
        concat.freeze(branch_qos)
    return out, concat.qo


class Model(nn.Module):
    """Config-driven network whose layers come from ``model_cfg_table[model_name]``.

    The float layers are created by ``make_layers`` at construction time. The
    quantization methods below all delegate to module-level helpers that walk
    the same ``cfg_table``.
    """

    def __init__(self,model_name,num_classes=10):
        super().__init__()
        # NOTE(review): ``num_classes`` is accepted but never read; every layer
        # size comes from the cfg table entry itself.
        self.cfg_table = model_cfg_table[model_name]
        make_layers(self, self.cfg_table)

    def forward(self,x):
        """Float forward pass over the layers built from ``cfg_table``."""
        return model_forward(self, self.cfg_table, x)

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Attach ``q_*`` quantized wrappers next to the existing float layers."""
        model_quantize(self, self.cfg_table, quant_type, num_bits, e_bits)

    def quantize_forward(self,x):
        """Forward pass through the quantized wrappers (``func='forward'``)."""
        return model_utils(self, self.cfg_table, func='forward', x=x)

    def freeze(self):
        """Freeze every quantized wrapper, chaining each layer's input quantizer."""
        model_utils(self, self.cfg_table, func='freeze')

    def quantize_inference(self,x):
        """Run quantized inference and return the dequantized output."""
        return model_utils(self, self.cfg_table, func='inference', x=x)

    def fakefreeze(self):
        """Invoke ``fakefreeze`` on the quantized conv / FC wrappers."""
        model_utils(self, self.cfg_table, func='fakefreeze')

# if __name__ == "__main__":
#     model = Inception_BN()
#     model.quantize('INT',8,3)
#     print(model.named_modules)
#     print('-------')
#     print(model.named_parameters)
#     print(len(model.conv0.named_parameters()))