Commit 147e0fbb by Zhihong Ma

fix: before BN trial

parent e230b520
......@@ -279,11 +279,13 @@ class LeNet(nn.Module):
# Feature extractor: two Conv -> BN -> ReLU -> MaxPool blocks for a
# 3-channel (CIFAR-10 style) input. The trailing comments keep the
# author's original per-layer operation-count bookkeeping.
self.conv_layers = nn.ModuleDict({
# block1
'conv1': nn.Conv2d(3,6,5), # (2*3*5*5) * 32*32*6 ops (bias accounts for 32*32*6 of them) 6144/921600
'bn1': nn.BatchNorm2d(6),
'reluc1': nn.ReLU(),
'pool1': nn.MaxPool2d(2,2),
# block2
'conv2': nn.Conv2d(6,16,5), # (2*6*5*5) * 16*16*16 ops (bias accounts for 16*16*6 of them) 1536/1228800
'bn2': nn.BatchNorm2d(16),
'reluc2': nn.ReLU(),
'pool2': nn.MaxPool2d(2,2),
})
......@@ -316,11 +318,15 @@ class LeNet(nn.Module):
# Quantized counterparts of conv_layers, with Conv+BN+ReLU fused into one
# QConvBNReLU module per block.
# qi=True: the previous layer's output is not yet quantized and must be
# quantized here. MaxPool and ReLU change neither the INT values nor the
# min/max range, so layers that follow them use qi=False; a conv layer does
# change the min/max, so the layer after it needs qi=True.
self.quantize_conv_layers=nn.ModuleDict({
    'qconvbnrelu1': QConvBNReLU(self.conv_layers['conv1'], self.conv_layers['bn1'],
                                qi=True, qo=True, num_bits=num_bits,
                                n_exp=self.n_exp, mode=self.mode),
    'qpool1': QMaxPooling2d(kernel_size=2, stride=2, padding=0,
                            n_exp=self.n_exp, mode=self.mode),
    # BUG FIX: this key was 'qconvbnrelu1' — a duplicate that silently
    # overwrote the first entry and made freeze()/fakefreeze()'s lookups of
    # 'qconvbnrelu2' raise KeyError.
    # NOTE(review): qi=True here looks redundant — qpool1 preserves the
    # min/max of qconvbnrelu1's output, so qi=False (as the replaced qconv2
    # used) is probably intended; confirm before changing.
    'qconvbnrelu2': QConvBNReLU(self.conv_layers['conv2'], self.conv_layers['bn2'],
                                qi=True, qo=True, num_bits=num_bits,
                                n_exp=self.n_exp, mode=self.mode),
    'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0,
                            n_exp=self.n_exp, mode=self.mode)
})
......@@ -347,37 +353,51 @@ class LeNet(nn.Module):
def freeze(self):
    """Freeze quantization parameters after calibration.

    Each fused Conv+BN+ReLU block freezes with its own qi/qo; scale-preserving
    layers (max-pool, ReLU) are frozen with the output quantizer (qo) of the
    layer that produced their input, and each FC layer takes the previous
    layer's qo as its qi.

    NOTE: the rendered diff interleaved the pre-commit calls (on the removed
    'qconv1'/'qconv2'/'qreluc*' keys) with the new ones; only the post-commit
    sequence below is kept — the stale calls would raise KeyError and
    double-freeze layers.
    """
    # Block 1: fused conv owns its quantizers; pool reuses its qo.
    self.quantize_conv_layers['qconvbnrelu1'].freeze()
    self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconvbnrelu1'].qo)
    # Block 2, likewise.
    self.quantize_conv_layers['qconvbnrelu2'].freeze()
    self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconvbnrelu2'].qo)
    # FC head: qpool2 preserves range, so qfc1's qi is qconvbnrelu2's qo;
    # the ReLUs likewise pass each FC layer's qo straight through.
    self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu2'].qo)
    self.quantize_fc_layers['qreluf1'].freeze(self.quantize_fc_layers['qfc1'].qo)
    self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qfc1'].qo)
    self.quantize_fc_layers['qreluf2'].freeze(self.quantize_fc_layers['qfc2'].qo)
    self.quantize_fc_layers['qfc3'].freeze(qi=self.quantize_fc_layers['qfc2'].qo)
def fakefreeze(self):
    """Fake-freeze quantization parameters (same wiring as :meth:`freeze`).

    Mirrors freeze(): fused Conv+BN+ReLU blocks fake-freeze standalone,
    scale-preserving layers receive the producing layer's qo, and each FC
    layer takes the previous layer's qo as its qi.

    NOTE: the rendered diff interleaved the pre-commit calls (on the removed
    'qconv1'/'qconv2'/'qreluc*' keys) with the new ones; only the post-commit
    sequence below is kept — the stale calls would raise KeyError.
    """
    # Block 1.
    self.quantize_conv_layers['qconvbnrelu1'].fakefreeze()
    self.quantize_conv_layers['qpool1'].fakefreeze(self.quantize_conv_layers['qconvbnrelu1'].qo)
    # Block 2.
    self.quantize_conv_layers['qconvbnrelu2'].fakefreeze()
    self.quantize_conv_layers['qpool2'].fakefreeze(self.quantize_conv_layers['qconvbnrelu2'].qo)
    # FC head — identical quantizer propagation to freeze().
    self.quantize_fc_layers['qfc1'].fakefreeze(qi=self.quantize_conv_layers['qconvbnrelu2'].qo)
    self.quantize_fc_layers['qreluf1'].fakefreeze(self.quantize_fc_layers['qfc1'].qo)
    self.quantize_fc_layers['qfc2'].fakefreeze(qi=self.quantize_fc_layers['qfc1'].qo)
    self.quantize_fc_layers['qreluf2'].fakefreeze(self.quantize_fc_layers['qfc2'].qo)
    self.quantize_fc_layers['qfc3'].fakefreeze(qi=self.quantize_fc_layers['qfc2'].qo)
def quantize_inference(self, x):
x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode)
# x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode)
x = self.quantize_conv_layers['qconvbnrelu1'].qi.quantize_tensor(x, self.mode)
for s, layer in self.quantize_conv_layers.items():
x = layer.quantize_inference(x)
......
......@@ -182,6 +182,10 @@ if __name__ == "__main__":
model = resnet50().to(device)
elif args.model == 'resnet152' :
model = resnet152().to(device)
elif args.model == 'LeNet' :
model = LeNet().to(device)
elif args.model == 'NetBN' :
model = NetBN().to(device)
criterion = nn.CrossEntropyLoss()
......
......@@ -11,6 +11,7 @@ from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
from resnet import *
......@@ -46,17 +47,25 @@ def quantize_inference(model, test_loader, device):
if __name__ == "__main__":
d1 = sys.argv[1] # num_bits
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
parser = argparse.ArgumentParser(description='PTQ Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
parser.add_argument('-n','--num_bits', default=8, type=int, metavar='BITS', help='number of bits')
parser.add_argument('-t','--mode', default=1, type=int, metavar='MODES', help='PTQ mode(1:INT 2:PoT 3:FP)')
parser.add_argument('-e','--n_exp', default=4, type=int, metavar='N_EXP', help='number of exp')
# d1 = sys.argv[1] # num_bits
# d2 = sys.argv[2] # mode
# d3 = sys.argv[3] # n_exp
# d1 = 8
# d2 = 3
# d3 = 4
batch_size = 32
args = parser.parse_args()
d1 = args.num_bits
d2 = args.mode
d3 = args.n_exp
batch_size = 128
using_bn = True
load_quant_model_file = None
# load_model_file = None
net = 'LeNet' # 1:
acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -80,9 +89,21 @@ if __name__ == "__main__":
)
if using_bn:
model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
# model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
if args.model == 'resnet18' :
model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet50' :
model = resnet50(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet152' :
model = resnet152(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'LeNet' :
model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'NetBN' :
model = NetBN().to(device)
# model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
# 生成梯度分布图的时候是从0开始训练的
model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
# model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
model.load_state_dict(torch.load('./project/p/ckpt/' + args.model + '/' + args.model + '.pt', map_location='cpu'))
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
......@@ -91,7 +112,7 @@ if __name__ == "__main__":
model.eval()
full_inference(model, test_loader, device)
full_writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log')
full_writer = SummaryWriter(log_dir='./project/p/' + args.model +'/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log')
for name, param in model.named_parameters():
full_writer.add_histogram(tag=name + '_data', values=param.data)
......@@ -99,7 +120,7 @@ if __name__ == "__main__":
model.quantize(num_bits=num_bits)
model.eval()
print('Quantization bit: %d' % num_bits)
writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')
writer = SummaryWriter(log_dir='./project/p/'+ args.model + '/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
......@@ -114,12 +135,12 @@ if __name__ == "__main__":
# 原PTQ mode=1时
# save_file = 'ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
dir_name ='./project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq'
dir_name ='./project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
save_file = './project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
save_file = './project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_' + args.model + '_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file)
......@@ -130,7 +151,7 @@ if __name__ == "__main__":
# print(model.qconv1.M.device)
acc = quantize_inference(model, test_loader, device)
f = open('./project/p/lenet_ptq_acc' + '.txt', 'a')
f = open('./project/p/' + args.model + '_ptq_acc' + '.txt', 'a')
f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')
f.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment