import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from get_weight import *
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import CosineAnnealingLR
from resnet import *
from torchvision.transforms import transforms
# import models

import time
import os


import argparse
# Model definition


class Net(nn.Module):
    """Small convolutional classifier with two conv/pool stages.

    Two unpadded 3x3 convolutions (3 -> 32 -> 64 channels), each followed by
    ReLU and 2x2 max pooling, feed a 512-unit hidden linear layer and a
    10-unit output layer. ``fc1`` expects 64 * 6 * 6 flattened features,
    which is what a 32x32 input yields after the two conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order define the state_dict keys,
        # so they are kept stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the 10 output scores for each sample in ``x``."""
        features = x
        # Both stages share the same conv -> ReLU -> max-pool pattern.
        for conv in (self.conv1, self.conv2):
            features = self.pool(self.relu(conv(features)))
        # Keep the batch dimension, collapse everything else.
        flat = torch.flatten(features, start_dim=1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)


def train(model, optimizer, criterion, train_loader, device):
    """Run one training epoch and return the mean loss and mean gradient stats.

    Args:
        model: Network to optimise; it is switched to training mode.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A ``(train_loss, grads_sum)`` tuple. ``train_loss`` is the sum of the
        per-batch losses divided by ``len(train_loader)``. ``grads_sum`` maps
        each key of the second value returned by ``get_model_histogram`` to
        the sum of its per-batch values divided by ``len(train_loader)``.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail with a clear message instead of an incidental
        # ZeroDivisionError when the averages are computed.
        raise ValueError("train_loader is empty; cannot train on zero batches")

    model.train()

    running_loss = 0.0
    grads_sum = None

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Collected after backward() and before step(), i.e. while this
        # batch's gradients are populated.
        _, grads = get_model_histogram(model)
        if grads_sum is None:
            # Shallow copy so the accumulator is not the helper's own dict.
            grads_sum = dict(grads)
        else:
            # Accumulate out of place: get_model_histogram is defined
            # elsewhere and may hand back live tensors (e.g. param.grad);
            # an in-place += on the first batch's objects could then modify
            # the gradients optimizer.step() is about to consume.
            for key in grads_sum:
                grads_sum[key] = grads_sum[key] + grads[key]

        optimizer.step()

        running_loss += loss.item()

    train_loss = running_loss / num_batches
    grads_sum = {key: total / num_batches for key, total in grads_sum.items()}

    return train_loss, grads_sum


def evaluate(model, criterion, test_loader, device):
    """Return the classification accuracy of ``model`` as a percentage.

    The model is put in eval mode and run without gradient tracking over
    every ``(images, labels)`` batch of ``test_loader``. The predicted class
    is the index of the largest output along dimension 1. ``criterion`` is
    accepted for signature compatibility but is not used.
    """
    model.eval()

    hits = 0
    seen = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            scores = model(images)
            # torch.max over dim 1 returns (values, indices); only the
            # indices (predicted classes) are needed.
            predicted = torch.max(scores.data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()

    return 100 * hits / seen


def get_children(model: torch.nn.Module):
    """Recursively flatten ``model`` into its leaf modules.

    A module without child modules is returned as-is. Otherwise the result
    is an ``nn.ModuleList`` (used so the collected modules stay registered
    and can be updated later) containing, in order, the flattened result of
    every child.
    """
    direct_children = list(model.children())

    # Base case: a module with no children is itself the last child.
    if len(direct_children) == 0:
        return model

    # Recursive case: descend into every child down to the leaves.
    flattened = nn.ModuleList()
    for child in direct_children:
        found = get_children(child)
        try:
            # A flattened sub-list is spliced in element by element ...
            flattened.extend(found)
        except TypeError:
            # ... whereas a result extend() rejects is added as one entry.
            flattened.append(found)

    return flattened


if __name__ == "__main__":
    # Script entry point: trains the selected ResNet on CIFAR-10 with
    # TensorBoard logging and periodic checkpoints, or (with --test) loads the
    # saved best weights, reports test accuracy and prints the quantized layers.

    parser = argparse.ArgumentParser(description='PyTorch FP32 Training')
    # `choices` rejects an unsupported architecture at parse time; without it
    # an unknown name would leave `model` undefined and crash with NameError.
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18',
                        choices=['resnet18', 'resnet50', 'resnet152'])
    parser.add_argument('-e','--epochs', default=100, type=int, metavar='EPOCHS', help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
    parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
    parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('-wd','--weight_decay',default=0.0001,type=float,metavar='WD',help='lr schduler weight decay',dest='wd')
    parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')

    # Training parameters
    args = parser.parse_args()

    num_epochs = args.epochs
    print(num_epochs)
    batch_size = args.batch_size
    print(batch_size)
    num_workers = args.workers
    lr = args.lr
    # NOTE(review): parsed but never passed to the optimizer below - confirm
    # whether weight decay is meant to be applied.
    weight_decay = args.wd

    best_acc = float("-inf")

    start_time = time.time()

    # Model, loss function and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.model == 'resnet18':
        model = resnet18().to(device)
    elif args.model == 'resnet50':
        model = resnet50().to(device)
    elif args.model == 'resnet152':
        model = resnet152().to(device)

    # Handle on the unwrapped network: nn.DataParallel does not forward
    # custom attributes such as quantize()/quantize_layers, and its
    # state_dict keys carry a "module." prefix.
    base_model = model

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print("ok!")

    # Data parallelism
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)

    # Data loading
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=True, download=False,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=2),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=False, download=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    # Learning-rate scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

    # TensorBoard
    writer = SummaryWriter(log_dir='./project/p/models_log/' + args.model  +  '/full_log')

    # Early-stopping parameters.
    # NOTE(review): `count` only advances when an evaluation runs (every 5th
    # epoch), so `patience` is measured in evaluations, not epochs.
    patience = 30
    count = 0

    # Directory for the best model weights
    save_dir = './project/p/ckpt/' + args.model
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, mode=0o777)
        os.chmod(save_dir, mode=0o777)

    # Directory for the periodic full checkpoints
    checkpoint_dir = './project/p/checkpoint/cifar-10_' + args.model
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir, mode=0o777)
        os.chmod(checkpoint_dir, mode=0o777)

    # try/finally so the TensorBoard writer is closed in test mode and when
    # training is interrupted, not only after a complete training run.
    try:
        if args.test:
            # map_location lets weights saved on a GPU be loaded on a
            # CPU-only machine.
            state_dict = torch.load(save_dir+'/' + args.model + '.pt',
                                    map_location=device)
            # Weights saved from a DataParallel-wrapped model are keyed
            # "module.<name>"; strip that prefix from keys the bare network
            # does not recognise so the file loads on any device count.
            prefix = 'module.'
            expected_keys = base_model.state_dict().keys()
            state_dict = {
                (key[len(prefix):]
                 if key not in expected_keys and key.startswith(prefix)
                 else key): value
                for key, value in state_dict.items()
            }
            base_model.load_state_dict(state_dict)

            acc = evaluate(model, criterion, test_loader, device=device)
            print(f"test accuracy: {acc:.2f}%")

            for name, module in model.named_modules():
                print(f"{name}: {module}\n")

            print('========================================================')
            print('========================================================')

            # quantize()/quantize_layers are defined on the network itself,
            # so call them on the unwrapped model.
            base_model.quantize()
            for name, layer in base_model.quantize_layers.items():
                print(f"Layer {name}: {layer} ")  # enough to walk every layer

        else:
            # Training loop
            for epoch in range(num_epochs):
                # Train for one epoch and record the loss
                train_loss, grads_sum = train(model, optimizer, criterion,
                                              train_loader, device=device)
                writer.add_scalar("Training Loss", train_loss, epoch + 1)

                # Evaluate and record accuracy every 5 epochs
                if (epoch + 1) % 5 == 0:
                    acc = evaluate(model, criterion, test_loader, device=device)
                    writer.add_scalar("Validation Accuracy", acc, epoch + 1)

                    checkpoint = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'grads': grads_sum,
                        'accuracy': acc,
                    }

                    # grads_sum is what train() returns: the per-batch values
                    # summed over the epoch and divided by the batch count.
                    for name, param in grads_sum.items():
                        writer.add_histogram(tag=name + '_grad', values=param, global_step=epoch)
                    # Weights as they stand after the last batch of this epoch
                    for name, param in model.named_parameters():
                        writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)

                    torch.save(checkpoint, checkpoint_dir + '/ckpt_cifar-10_' + args.model + '_%s.pt' % (str(epoch+1)))

                    # Keep the weights of the best model seen so far
                    if acc > best_acc:
                        best_acc = acc
                        count = 0
                        torch.save(model.state_dict(), save_dir+'/'  + args.model  +  '.pt')
                    else:
                        count += 1

                    print(
                        f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.5f}, Val Acc: {acc:.2f}%")

                # Stop early once the patience budget is used up
                if count == patience:
                    print(f"No improvement after {patience} epochs. Early stop!")
                    break

                # Update the learning rate
                lr_scheduler.step()

            # Training time and best validation accuracy
            print(f"Training took {(time.time() - start_time) / 60:.2f} minutes")
            print(f"Best validation accuracy: {best_acc:.2f}%")
    finally:
        # Close the TensorBoard writer
        writer.close()


