Commit 147e0fbb by Zhihong Ma

fix: before BN trial

parent e230b520
...@@ -279,11 +279,13 @@ class LeNet(nn.Module): ...@@ -279,11 +279,13 @@ class LeNet(nn.Module):
self.conv_layers = nn.ModuleDict({ self.conv_layers = nn.ModuleDict({
# block1 # block1
'conv1': nn.Conv2d(3,6,5), # (2*3*5*5) * 32*32*6 (bias占其中的32*32*6) 6144/921600 'conv1': nn.Conv2d(3,6,5), # (2*3*5*5) * 32*32*6 (bias占其中的32*32*6) 6144/921600
'bn1': nn.BatchNorm2d(6),
'reluc1': nn.ReLU(), 'reluc1': nn.ReLU(),
'pool1': nn.MaxPool2d(2,2), 'pool1': nn.MaxPool2d(2,2),
# block2 # block2
'conv2': nn.Conv2d(6,16,5), # (2*6*5*5) * 16*16*16 (bias占其中的16*16*6) 1536/1228800 'conv2': nn.Conv2d(6,16,5), # (2*6*5*5) * 16*16*16 (bias占其中的16*16*6) 1536/1228800
'bn2': nn.BatchNorm2d(16),
'reluc2': nn.ReLU(), 'reluc2': nn.ReLU(),
'pool2': nn.MaxPool2d(2,2), 'pool2': nn.MaxPool2d(2,2),
}) })
...@@ -316,11 +318,15 @@ class LeNet(nn.Module): ...@@ -316,11 +318,15 @@ class LeNet(nn.Module):
self.quantize_conv_layers=nn.ModuleDict({ self.quantize_conv_layers=nn.ModuleDict({
# qi=true: 前一层输出的结果是没有量化过的,需要量化。 maxpool和relu都不会影响INT和minmax,所以在这俩之后的层的qi是false # qi=true: 前一层输出的结果是没有量化过的,需要量化。 maxpool和relu都不会影响INT和minmax,所以在这俩之后的层的qi是false
#若前一层是conv,数据minmax被改变,则需要qi=true来量化 #若前一层是conv,数据minmax被改变,则需要qi=true来量化
'qconv1': QConv2d(self.conv_layers['conv1'], qi=True, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode), # 'qconv1': QConv2d(self.conv_layers['conv1'], qi=True, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
'qreluc1': QReLU(n_exp=self.n_exp, mode=self.mode), # 'qreluc1': QReLU(n_exp=self.n_exp, mode=self.mode),
'qconvbnrelu1': QConvBNReLU(self.conv_layers['conv1'],self.conv_layers['bn1'],qi=True,qo=True,num_bits=num_bits,n_exp=self.n_exp,mode=self.mode),
'qpool1': QMaxPooling2d(kernel_size=2,stride=2,padding=0, n_exp=self.n_exp, mode=self.mode), 'qpool1': QMaxPooling2d(kernel_size=2,stride=2,padding=0, n_exp=self.n_exp, mode=self.mode),
'qconv2': QConv2d(self.conv_layers['conv2'], qi=False, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
'qreluc2': QReLU(n_exp=self.n_exp, mode=self.mode), # 'qconv2': QConv2d(self.conv_layers['conv2'], qi=False, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
# 'qreluc2': QReLU(n_exp=self.n_exp, mode=self.mode),
'qconvbnrelu2': QConvBNReLU(self.conv_layers['conv2'],self.conv_layers['bn2'],qi=True,qo=True,num_bits=num_bits,n_exp=self.n_exp,mode=self.mode),
'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0, n_exp=self.n_exp, mode=self.mode) 'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0, n_exp=self.n_exp, mode=self.mode)
}) })
...@@ -347,37 +353,51 @@ class LeNet(nn.Module): ...@@ -347,37 +353,51 @@ class LeNet(nn.Module):
def freeze(self): def freeze(self):
self.quantize_conv_layers['qconv1'].freeze() # self.quantize_conv_layers['qconv1'].freeze()
self.quantize_conv_layers['qreluc1'].freeze(self.quantize_conv_layers['qconv1'].qo) # self.quantize_conv_layers['qreluc1'].freeze(self.quantize_conv_layers['qconv1'].qo)
self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconv1'].qo) self.quantize_conv_layers['qconvbnrelu1'].freeze()
#self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconv1'].qo)
self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconvbnrelu1'].qo)
# self.quantize_conv_layers['qconv2'].freeze(self.quantize_conv_layers['qconv1'].qo)
# self.quantize_conv_layers['qreluc2'].freeze(self.quantize_conv_layers['qconv2'].qo)
self.quantize_conv_layers['qconvbnrelu2'].freeze()
# self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconv2'].qo)
self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconvbnrelu2'].qo)
self.quantize_conv_layers['qconv2'].freeze(self.quantize_conv_layers['qconv1'].qo) # self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconv2'].qo)
self.quantize_conv_layers['qreluc2'].freeze(self.quantize_conv_layers['qconv2'].qo) self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconvbnrelu2'].qo)
self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconv2'].qo)
self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconv2'].qo)
self.quantize_fc_layers['qreluf1'].freeze(self.quantize_fc_layers['qfc1'].qo) self.quantize_fc_layers['qreluf1'].freeze(self.quantize_fc_layers['qfc1'].qo)
self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qfc1'].qo) self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qfc1'].qo)
self.quantize_fc_layers['qreluf2'].freeze(self.quantize_fc_layers['qfc2'].qo) self.quantize_fc_layers['qreluf2'].freeze(self.quantize_fc_layers['qfc2'].qo)
self.quantize_fc_layers['qfc3'].freeze(qi=self.quantize_fc_layers['qfc2'].qo) self.quantize_fc_layers['qfc3'].freeze(qi=self.quantize_fc_layers['qfc2'].qo)
def fakefreeze(self): def fakefreeze(self):
self.quantize_conv_layers['qconv1'].fakefreeze() # self.quantize_conv_layers['qconv1'].fakefreeze()
self.quantize_conv_layers['qreluc1'].fakefreeze(self.quantize_conv_layers['qconv1'].qo) # self.quantize_conv_layers['qreluc1'].fakefreeze(self.quantize_conv_layers['qconv1'].qo)
self.quantize_conv_layers['qpool1'].fakefreeze(self.quantize_conv_layers['qconv1'].qo) # self.quantize_conv_layers['qpool1'].fakefreeze(self.quantize_conv_layers['qconv1'].qo)
self.quantize_conv_layers['qconvbnrelu1'].fakefreeze()
self.quantize_conv_layers['qpool1'].fakefreeze(self.quantize_conv_layers['qconvbnrelu1'].qo)
# self.quantize_conv_layers['qconv2'].fakefreeze(self.quantize_conv_layers['qconv1'].qo)
# self.quantize_conv_layers['qreluc2'].fakefreeze(self.quantize_conv_layers['qconv2'].qo)
# self.quantize_conv_layers['qpool2'].fakefreeze(self.quantize_conv_layers['qconv2'].qo)
self.quantize_conv_layers['qconv2'].fakefreeze(self.quantize_conv_layers['qconv1'].qo) self.quantize_conv_layers['qconvbnrelu2'].fakefreeze()
self.quantize_conv_layers['qreluc2'].fakefreeze(self.quantize_conv_layers['qconv2'].qo) self.quantize_conv_layers['qpool2'].fakefreeze(self.quantize_conv_layers['qconvbnrelu2'].qo)
self.quantize_conv_layers['qpool2'].fakefreeze(self.quantize_conv_layers['qconv2'].qo)
self.quantize_fc_layers['qfc1'].fakefreeze(qi=self.quantize_conv_layers['qconv2'].qo) # self.quantize_fc_layers['qfc1'].fakefreeze(qi=self.quantize_conv_layers['qconv2'].qo)
self.quantize_fc_layers['qfc1'].fakefreeze(qi=self.quantize_conv_layers['qconvbnrelu2'].qo)
self.quantize_fc_layers['qreluf1'].fakefreeze(self.quantize_fc_layers['qfc1'].qo) self.quantize_fc_layers['qreluf1'].fakefreeze(self.quantize_fc_layers['qfc1'].qo)
self.quantize_fc_layers['qfc2'].fakefreeze(qi=self.quantize_fc_layers['qfc1'].qo) self.quantize_fc_layers['qfc2'].fakefreeze(qi=self.quantize_fc_layers['qfc1'].qo)
self.quantize_fc_layers['qreluf2'].fakefreeze(self.quantize_fc_layers['qfc2'].qo) self.quantize_fc_layers['qreluf2'].fakefreeze(self.quantize_fc_layers['qfc2'].qo)
self.quantize_fc_layers['qfc3'].fakefreeze(qi=self.quantize_fc_layers['qfc2'].qo) self.quantize_fc_layers['qfc3'].fakefreeze(qi=self.quantize_fc_layers['qfc2'].qo)
def quantize_inference(self, x): def quantize_inference(self, x):
x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode) # x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode)
x = self.quantize_conv_layers['qconvbnrelu1'].qi.quantize_tensor(x, self.mode)
for s, layer in self.quantize_conv_layers.items(): for s, layer in self.quantize_conv_layers.items():
x = layer.quantize_inference(x) x = layer.quantize_inference(x)
......
...@@ -182,6 +182,10 @@ if __name__ == "__main__": ...@@ -182,6 +182,10 @@ if __name__ == "__main__":
model = resnet50().to(device) model = resnet50().to(device)
elif args.model == 'resnet152' : elif args.model == 'resnet152' :
model = resnet152().to(device) model = resnet152().to(device)
elif args.model == 'LeNet' :
model = LeNet().to(device)
elif args.model == 'NetBN' :
model = NetBN().to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
......
...@@ -11,6 +11,7 @@ from torchvision import datasets, transforms ...@@ -11,6 +11,7 @@ from torchvision import datasets, transforms
import os import os
import os.path as osp import os.path as osp
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from resnet import *
...@@ -46,17 +47,25 @@ def quantize_inference(model, test_loader, device): ...@@ -46,17 +47,25 @@ def quantize_inference(model, test_loader, device):
if __name__ == "__main__": if __name__ == "__main__":
d1 = sys.argv[1] # num_bits parser = argparse.ArgumentParser(description='PTQ Training')
d2 = sys.argv[2] # mode parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
d3 = sys.argv[3] # n_exp parser.add_argument('-n','--num_bits', default=8, type=int, metavar='BITS', help='number of bits')
parser.add_argument('-t','--mode', default=1, type=int, metavar='MODES', help='PTQ mode(1:INT 2:PoT 3:FP)')
parser.add_argument('-e','--n_exp', default=4, type=int, metavar='N_EXP', help='number of exp')
# d1 = sys.argv[1] # num_bits
# d2 = sys.argv[2] # mode
# d3 = sys.argv[3] # n_exp
# d1 = 8 # d1 = 8
# d2 = 3 # d2 = 3
# d3 = 4 # d3 = 4
batch_size = 32 args = parser.parse_args()
d1 = args.num_bits
d2 = args.mode
d3 = args.n_exp
batch_size = 128
using_bn = True using_bn = True
load_quant_model_file = None load_quant_model_file = None
# load_model_file = None # load_model_file = None
net = 'LeNet' # 1:
acc = 0 acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
...@@ -80,9 +89,21 @@ if __name__ == "__main__": ...@@ -80,9 +89,21 @@ if __name__ == "__main__":
) )
if using_bn: if using_bn:
# model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
if args.model == 'resnet18' :
model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet50' :
model = resnet50(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet152' :
model = resnet152(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'LeNet' :
model = LeNet(n_exp=int(d3), mode=int(d2)).to(device) model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'NetBN' :
model = NetBN().to(device)
# model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
# 生成梯度分布图的时候是从0开始训练的 # 生成梯度分布图的时候是从0开始训练的
model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu')) # model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
model.load_state_dict(torch.load('./project/p/ckpt/' + args.model + '/' + args.model + '.pt', map_location='cpu'))
# else: # else:
# model = Net() # model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu')) # model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
...@@ -91,7 +112,7 @@ if __name__ == "__main__": ...@@ -91,7 +112,7 @@ if __name__ == "__main__":
model.eval() model.eval()
full_inference(model, test_loader, device) full_inference(model, test_loader, device)
full_writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log') full_writer = SummaryWriter(log_dir='./project/p/' + args.model +'/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log')
for name, param in model.named_parameters(): for name, param in model.named_parameters():
full_writer.add_histogram(tag=name + '_data', values=param.data) full_writer.add_histogram(tag=name + '_data', values=param.data)
...@@ -99,7 +120,7 @@ if __name__ == "__main__": ...@@ -99,7 +120,7 @@ if __name__ == "__main__":
model.quantize(num_bits=num_bits) model.quantize(num_bits=num_bits)
model.eval() model.eval()
print('Quantization bit: %d' % num_bits) print('Quantization bit: %d' % num_bits)
writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log') writer = SummaryWriter(log_dir='./project/p/'+ args.model + '/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')
if load_quant_model_file is not None: if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file)) model.load_state_dict(torch.load(load_quant_model_file))
...@@ -114,12 +135,12 @@ if __name__ == "__main__": ...@@ -114,12 +135,12 @@ if __name__ == "__main__":
# 原PTQ mode=1时 # 原PTQ mode=1时
# save_file = 'ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt' # save_file = 'ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
dir_name ='./project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq' dir_name ='./project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq'
if not os.path.isdir(dir_name): if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777) os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777) os.chmod(dir_name, mode=0o777)
save_file = './project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt' save_file = './project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_' + args.model + '_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file) torch.save(model.state_dict(), save_file)
...@@ -130,7 +151,7 @@ if __name__ == "__main__": ...@@ -130,7 +151,7 @@ if __name__ == "__main__":
# print(model.qconv1.M.device) # print(model.qconv1.M.device)
acc = quantize_inference(model, test_loader, device) acc = quantize_inference(model, test_loader, device)
f = open('./project/p/lenet_ptq_acc' + '.txt', 'a') f = open('./project/p/' + args.model + '_ptq_acc' + '.txt', 'a')
f.write('bit ' + str(d1) + ': ' + str(acc) + '\n') f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')
f.close() f.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment