# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *

import argparse
import torch
import sys
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter


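# Goal: obtain a fake-quantized copy of the PTQ weights (quantize, then
# dequantize). Its distribution stays close to that of the full-precision
# weights, so the two can be compared with the Wasserstein distance.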

def direct_quantize(model, test_loader, device, max_batches=5000):
    """Calibrate the model's quantization parameters with forward passes.

    Runs ``model.quantize_forward`` on batches from ``test_loader``. Each call
    runs the quantized forward of every layer, which updates that layer's
    quantization statistics (e.g. ``qw``). The outputs themselves are discarded.

    Args:
        model: Object exposing ``quantize_forward(data)``.
        test_loader: Iterable of ``(data, target)`` batches; only ``data`` is
            used. (The script passes the training loader here.)
        device: Device each batch is moved to before the forward pass.
        max_batches: Process at most this many batches; ``None`` processes the
            whole loader. The default of 5000 matches the original
            ``i % 5000 == 0`` cut-off.
    """
    from itertools import islice

    batches = test_loader if max_batches is None else islice(test_loader, max_batches)
    # Calibration only needs forward activations, so skip building autograd graphs.
    with torch.no_grad():
        for data, _target in batches:
            model.quantize_forward(data.to(device))
    print('direct quantization finish')


def full_inference(model, test_loader, device):
    """Evaluate the full-precision model and print its top-1 accuracy.

    Args:
        model: Callable mapping a batch of inputs to per-class scores; the
            prediction is the argmax over dim 1.
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            ``dataset`` (its length is the total sample count).
        device: Device each batch is moved to.

    Returns:
        Accuracy in percent, mirroring ``quantize_inference``.
    """
    correct = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(acc))
    return acc


def quantize_inference(model, test_loader, device):
    """Evaluate the quantized model and print its top-1 accuracy.

    Args:
        model: Object exposing ``quantize_inference(data)``, which returns
            per-class scores; the prediction is the argmax over dim 1.
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            ``dataset`` (its length is the total sample count).
        device: Device each batch is moved to.

    Returns:
        Accuracy in percent.
    """
    correct = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model.quantize_inference(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
    return acc


if __name__ == "__main__":
    d1 = sys.argv[1]
    batch_size = 32
    using_bn = True
    load_quant_model_file = None
    # load_model_file = None
    net = 'LeNet'  # 1:
    acc = 0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )

    if using_bn:
        model = LeNet().to(device)
        # 生成梯度分布图的时候是从0开始训练的
        model.load_state_dict(torch.load('ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
    # else:
    #     model = Net()
    #     model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
    #     save_file = "ckpt/mnist_cnn_ptq.pt"
    # model.to(device)
    model.eval()
    full_inference(model, test_loader, device)



    num_bits = int(d1)
    model.quantize(num_bits=num_bits)
    model.eval()
    print('Quantization bit: %d' % num_bits)

    dir_name = './ptq_fake_log/' + 'quant_bit_' + str(d1) + '_log'
    if not os.path.isdir(dir_name):
        os.makedirs(dir_name, mode=0o777)
        os.chmod(dir_name, mode=0o777)

    qwriter = SummaryWriter(log_dir=dir_name)
    # for name, param in model.named_parameters():
    #     qwriter.add_histogram(tag=name + '_data', values=param.data)

    if load_quant_model_file is not None:
        model.load_state_dict(torch.load(load_quant_model_file))
        print("Successfully load quantized model %s" % load_quant_model_file)

    direct_quantize(model, train_loader, device)

    model.fakefreeze()  # 权重量化

    for name, param in model.named_parameters():
        qwriter.add_histogram(tag=name + '_data', values=param.data)


    dir_name ='ckpt/ptq_fakefreeze'
    if not os.path.isdir(dir_name):
        os.makedirs(dir_name, mode=0o777)
        os.chmod(dir_name, mode=0o777)

    save_file = 'ckpt/ptq_fakefreeze/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
    torch.save(model.state_dict(), save_file)







