# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
import os

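# Overview: for a given quantization bit width, this script computes the gradient-distribution similarity at several epoch nodes (5, 10, ...). Because each node covers a span of epochs, the similarity is averaged over the epochs up to that node.
# Usage: the script is invoked once per bit width.
# Each CSV row records, for one bit width, the average weighted gradient-distribution similarity at each epoch node.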
#
# Positional command-line arguments, kept as the raw strings from argv.
d1 = sys.argv[1]  # bit
d2 = sys.argv[2]  # mode
d3 = sys.argv[3]  # n_exp (only appended to the CSV file name when mode is 3)

# Each mode writes into its own output directory.
dirpath = './project/p/qat_analysis_data/mode' + str(d2)

if not os.path.isdir(dirpath):
    os.makedirs(dirpath, mode=0o777)
    # The mode passed to makedirs is filtered by the process umask, so the
    # permissions are applied once more explicitly.
    os.chmod(dirpath, mode=0o777)

# Mode 3 keeps one CSV per n_exp; every other mode shares a single CSV.
if int(d2) == 3:
    csvpath = dirpath + '/wasserstein_distance_' + str(d3) + '.csv'
else:
    csvpath = dirpath + '/wasserstein_distance.csv'

# tag == 1 means the CSV already exists and this run adds a row to it;
# tag == 0 means this run creates the CSV.
tag = int(os.path.exists(csvpath))

if tag:
    # Existing rows are read back into df; df2 collects the values of this run.
    df = pd.read_csv(csvpath, index_col=0)
    df2 = pd.DataFrame()
else:
    # No CSV yet: the values of this run are written straight into df.
    df = pd.DataFrame()


# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K)   (bias counted; otherwise subtract 1)
# FCN FLOPs = Cout * Cin   (bias counted; otherwise subtract 1)


# The associated ReLU and pooling operations are counted as well
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 ,  480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
#            20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 ,  480,000.0,+  95,880.0 ,
#            20,076.0  , 1,670.0 ])

# Per-layer FLOPs, one entry per layer (five layers in total).
weight_f1 = np.array([357504, 241600, 48000, 10080, 840])
# Same table with extra per-layer terms added on top of weight_f1
# (presumably the cost of the activation / pooling ops that follow each
# layer -- TODO confirm).
weight_f0 = weight_f1 + np.array([4704 + 4704, 1600 + 1600, 120, 84, 0])

# Loop over the epoch nodes. For every node, the FLOPs-weighted Wasserstein
# distance between the baseline ("full") gradients and the quantized gradients
# is computed for each epoch 1..node, averaged over those epochs, and stored
# as one column (named after the node) of the output DataFrame.
for epoch in [5, 10, 15, 20, 25, 30]:
    # For modes other than 1 the quantized checkpoint path depends only on the
    # epoch node, so the file is read once per node instead of once per inner
    # epoch; the loaded values are identical.
    # NOTE(review): mode 1 indexes the quantized checkpoint by the inner epoch
    # (i + 1) while the other modes index it by `epoch`, which makes every row
    # of the quantized matrix identical -- confirm that this is intended.
    if int(d2) == 1:
        node_quan_ckpt = None
    elif int(d2) == 2:
        node_quan_ckpt = torch.load(
            './project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '/' + str(d1)
            + '/ckpt_cifar-10_lenet_bn_quant_' + str(epoch) + '.pth')
    else:
        node_quan_ckpt = torch.load(
            './project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '_' + str(d3) + '/' + str(d1)
            + '/ckpt_cifar-10_lenet_bn_quant_' + str(epoch) + '.pth')

    total_quan_list = []
    total_base_list = []
    # Gradient name -> list of (1, n) rows, one row per inner epoch.
    quan_rows = OrderedDict()
    base_rows = OrderedDict()

    # Collect the flattened gradients of every epoch covered by this node.
    for i in range(epoch):
        if node_quan_ckpt is None:
            quan_ckpt = torch.load(
                './project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + str(d1)
                + '/ckpt_cifar-10_lenet_bn_quant_' + str(i + 1) + '.pth')
        else:
            quan_ckpt = node_quan_ckpt
        base_ckpt = torch.load(
            './project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i + 1) + '.pth')
        total_quan_list.append(quan_ckpt)
        total_base_list.append(base_ckpt)

        # The baseline checkpoint decides which gradient entries are compared.
        for k in base_ckpt['grads']:
            quan_rows.setdefault(k, []).append(quan_ckpt['grads'][k].reshape(1, -1))
            base_rows.setdefault(k, []).append(base_ckpt['grads'][k].reshape(1, -1))

    # Stack the rows once per entry: row j holds the gradients of epoch j + 1.
    summary_quan_dict = OrderedDict((k, np.vstack(rows)) for k, rows in quan_rows.items())
    summary_base_dict = OrderedDict((k, np.vstack(rows)) for k, rows in base_rows.items())

    # Normaliser shared by all epochs of this node.
    flops_norm = weight_f1.sum() * 2

    result = 0
    for j in range(epoch):
        # Weighted distance accumulated over all gradient entries of epoch j.
        epoch_distance = 0
        for position, k in enumerate(summary_quan_dict):
            # Entries are consumed in pairs that share one FLOPs weight: the
            # first of each pair is treated as the layer's weight tensor and
            # the second as its bias.
            # NOTE(review): assumes 'grads' alternates weight, bias per layer
            # and holds at most len(weight_f1) layers -- TODO confirm.
            layer_idx, is_bias = divmod(position, 2)
            w = summary_base_dict[k][j, :]
            v = summary_quan_dict[k][j, :]
            weighted = weight_f1[layer_idx] * scipy.stats.wasserstein_distance(w, v)
            # Weight tensors contribute with factor 0.99, bias tensors with 0.01.
            epoch_distance += (0.01 if is_bias else 0.99) * weighted

        epoch_distance = epoch_distance / flops_norm
        result += epoch_distance
        print(epoch_distance)

    # Mean over the epochs that belong to this node.
    result /= epoch

    # When the CSV already exists (tag == 1) the new values go into df2 so
    # that they can be added as a separate row; otherwise into df.
    if tag == 1:
        df2[str(epoch)] = [result]
    else:
        df[str(epoch)] = [result]








# Persist the results. When the CSV already existed (tag == 1), the values of
# this run (df2) are added as a new row below the rows read from the file.
if tag == 1:
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat with its default arguments performs the same row-wise stacking.
    df = pd.concat([df, df2])
df.to_csv(csvpath)






# f = open('lenet_ptq_wasserstein_similarity.txt','a')
# f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
# f.close()

