Commit 91d53d31 by Zhihong Ma

Delete data_analysis_mmd.py

parent 70d2f7fe
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from mmd_loss import *
from collections import OrderedDict
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # epoch
# d1=4
# d2=5
sum=0
flag=0
total_quan_list=list()
total_base_list=list()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K ) 是考虑bias 否则-1
# FCN FLOPs = Cout * Cin 是考虑bias 否则-1
# 把相关的relu,pool也考虑进去了
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 , 480,000.0,+ 95,880.0 ,
# 20,076.0 , 1,670.0 ])
# flops
weight_f0= np.array([357504+4704+4704, 241600+1600+1600,48000+120,10080+84,840])
weight_f1=np.array([357504, 241600,48000,10080,840])
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
losses=[]
# 最外层:不同epoch的字典 内层:各个网络层的grads
for i in range(int(d2)):
total_quan_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn_quant/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_'+str(i+1)+'.pth'))
#total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
total_base_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i+1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items():
if flag == 0:
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1,-1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1,-1)
else :
# 字典里的数据不能直接改,需要重新赋值
a=summary_quan_dict[k]
b=total_quan_list[i]['grads'][k].reshape(1,-1)
c=np.vstack((a,b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1,-1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
cnt = 0
flag = 0
for k, _ in summary_quan_dict.items():
if flag == 0:
sum += 0.99*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) # weight
else:
sum += 0.01*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) #bias
if flag == 1:
cnt = cnt + 1
flag = 0
else:
flag=1
sum=sum/(weight_f0.sum()*2)
print(sum)
f = open('./project/p/lenet_ptq_similarity.txt','a')
f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
f.close()
# for k,v in summary_base_dict.items():
# if k== 'conv_layers.conv1.weight':
# print(v)
# print('===========')
# print(summary_quan_dict[k])
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment