from model import *

import torch
from ptflops import get_model_complexity_info
import argparse

def get_children(model: torch.nn.Module):
    """Recursively flatten ``model`` into the leaf modules beneath it.

    A leaf is a module with no child modules. Children are visited in
    ``model.children()`` order, depth first.

    Args:
        model: the module to flatten.

    Returns:
        ``model`` itself when it has no children; otherwise a plain Python
        ``list`` of every leaf module beneath it. The list is not an
        ``nn.ModuleList``, so its entries are not registered as submodules
        of anything -- wrap it yourself if their parameters must be tracked.
    """
    children = list(model.children())
    if not children:
        # No children: this module is itself a leaf.
        return model

    flat_children = []
    for child in children:
        sub = get_children(child)
        # The recursive call returns a list only for non-leaf children. Check
        # the return type rather than trying ``extend`` and catching
        # TypeError: leaf modules that are themselves iterable (an empty
        # nn.Sequential, nn.ParameterList, ...) would otherwise be silently
        # dropped or replaced by their contents.
        if isinstance(sub, list):
            flat_children.extend(sub)
        else:
            flat_children.append(sub)
    return flat_children

# Generator over all sub-modules of a model, with nn.Sequential wrappers flattened away.
def get_all_child_modules(module):
    """Yield the sub-modules of ``module`` with ``nn.Sequential`` wrappers removed.

    For each direct child of ``module``:

    * an ``nn.Sequential`` is descended into recursively;
    * any other module that has children contributes its direct children
      (one level only -- those are not expanded any further);
    * a module with no children is yielded unchanged.

    This is a generator, so the traversal happens lazily as it is iterated.
    """
    for sub in module.children():
        if isinstance(sub, nn.Sequential):
            yield from get_all_child_modules(sub)
            continue
        if next(sub.children(), None) is None:
            # No children at all: ``sub`` is a leaf.
            yield sub
        else:
            yield from sub.children()

def filter_fn(module, n_inp, outp_shape):
    """Return True when ``module`` contains one of the tracked layer tags.

    The tags are checked with the ``in`` operator in this order:
    'conv', 'bn', 'fc', 'avg', 'relu'. In practice ``module`` therefore
    needs to be a string such as a layer name (NOTE(review): confirm what
    callers actually pass). ``n_inp`` and ``outp_shape`` are not used here.
    """
    tracked_tags = ('conv', 'bn', 'fc', 'avg', 'relu')
    return any(tag in module for tag in tracked_tags)

ARCH_NAMES = ('ResNet18', 'ResNet50', 'ResNet152')


def resolve_arch_name(name):
    """Map a case-insensitive architecture name onto its canonical spelling.

    Used as the argparse ``type`` for ``--model``, so that e.g. ``resnet18``
    (the default) and ``ResNet18`` select the same model and the same
    checkpoint file.

    Raises:
        argparse.ArgumentTypeError: if ``name`` is not one of ``ARCH_NAMES``.
    """
    wanted = name.lower()
    for arch in ARCH_NAMES:
        if arch.lower() == wanted:
            return arch
    raise argparse.ArgumentTypeError(
        f"unknown model {name!r}; choose from {', '.join(ARCH_NAMES)}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
    # argparse also applies ``type`` to a string default, so the default
    # 'resnet18' is normalised to 'ResNet18' and selects a real model.
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18',
                        type=resolve_arch_name,
                        help='one of ' + ', '.join(ARCH_NAMES) + ' (case-insensitive)')
    args = parser.parse_args()

    # Model constructors come from ``from model import *``.
    builders = {
        'ResNet18': resnet18,
        'ResNet50': resnet50,
        'ResNet152': resnet152,
    }
    model = builders[args.model]()

    full_file = 'ckpt/cifar10_' + args.model + '.pt'
    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only machine.
    model.load_state_dict(torch.load(full_file, map_location='cpu'))
    flops, params = get_model_complexity_info(
        model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print(f'Computational complexity: {flops}')
    print(f'Number of parameters: {params}')

