Commit e230b520 by Zhihong Ma

fix: modify optimizer, scheduler, data augmentation for higher acc

parent 12b90b92
@@ -6,7 +6,7 @@ from get_weight import *
 from torch.utils.tensorboard import SummaryWriter
 from torchvision import datasets, transforms
 from torchvision.datasets import CIFAR10
+from torch.optim.lr_scheduler import CosineAnnealingLR
 from resnet import *
 from torchvision.transforms import transforms
 # import models
@@ -135,8 +135,10 @@ if __name__ == "__main__":
     parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
     parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
     parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
+    parser.add_argument('-wd', '--weight_decay', default=0.0001, type=float, metavar='WD', help='optimizer weight decay (default: 0.0001)', dest='wd')
     parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
     # models = ['resnet18', 'resnet50', 'resnet152','resnet18']
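
With the new -wd flag, weight decay can be set from the command line alongside the existing options. A hypothetical invocation spelling out the defaults (the script's file name is not visible in this diff; train.py is assumed):

python train.py -b 128 -j 4 -lr 0.001 -wd 0.0001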
@@ -150,6 +152,7 @@ if __name__ == "__main__":
     print(batch_size)
     num_workers = args.workers
     lr = args.lr
+    weight_decay = args.wd
     best_acc = float("-inf")
@@ -183,6 +186,8 @@ if __name__ == "__main__":
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=lr)
+    # optimizer = optim.AdaBound(model.parameters(), lr=lr,
+    #                            weight_decay=weight_decay, final_lr=0.001*lr)
     print("ok!")
     # Data parallelism
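
A note on the commented-out alternative: AdaBound is not part of torch.optim, so optim.AdaBound would raise an AttributeError if uncommented as written. It ships as the separate adabound package (pip install adabound); a sketch of the equivalent call under that assumption:

import adabound

# AdaBound starts out Adam-like and bounds its step sizes toward an
# SGD-like final_lr over training; lr and weight_decay here are the
# script's argparse values.
optimizer = adabound.AdaBound(model.parameters(), lr=lr,
                              weight_decay=weight_decay, final_lr=0.001 * lr)

Note that the package's examples typically use a final_lr larger than lr (the default is 0.1); the 0.001 * lr above only mirrors the value in this commit.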
@@ -199,6 +204,8 @@ if __name__ == "__main__":
     train_loader = torch.utils.data.DataLoader(
         datasets.CIFAR10('./project/p/data', train=True, download=False,
                          transform=transforms.Compose([
+                             transforms.RandomCrop(32, padding=2),
+                             transforms.RandomHorizontalFlip(),
                              transforms.ToTensor(),
                              transforms.Normalize(
                                  (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
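
The two added transforms are the standard CIFAR-10 augmentation recipe: RandomCrop(32, padding=2) pads each image to 36x36 and crops a random 32x32 window, jittering object position by up to 2 pixels, and RandomHorizontalFlip mirrors it with probability 0.5. Both operate on PIL images, so they must precede ToTensor(), as they do here. A minimal standalone sketch of the resulting training pipeline (padding=4 is the more common choice in the literature; 2 is kept to match the commit):

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=2),   # pad to 36x36, crop a random 32x32 window
    transforms.RandomHorizontalFlip(),      # mirror left-right with probability 0.5
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # per-channel CIFAR-10 mean
                         (0.2023, 0.1994, 0.2010)),  # per-channel CIFAR-10 std
])

The test loader should keep only ToTensor() and Normalize(), since augmenting at evaluation time would add noise to the accuracy measurement.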
@@ -219,7 +226,8 @@ if __name__ == "__main__":
     # test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
     # Learning-rate scheduler
-    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
+    # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
+    lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
     # TensorBoard
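
Switching from StepLR to CosineAnnealingLR replaces a tenfold drop every 5 epochs with a smooth cosine curve that anneals the learning rate from its initial value toward 0 (the default eta_min) over T_max scheduler steps. The schedule only advances on lr_scheduler.step(), which with T_max=num_epochs should be called once per epoch, after the optimizer updates. A minimal sketch of that interaction with a toy model (num_epochs = 100 stands in for the script's epoch count):

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

model = nn.Linear(8, 8)                  # toy stand-in for the ResNet
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100                         # stands in for the script's value
lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # ... one training epoch: forward pass, loss.backward(), optimizer.step() ...
    lr_scheduler.step()                  # lr = 0.5 * lr0 * (1 + cos(pi * t / T_max))
    print(epoch, lr_scheduler.get_last_lr()[0])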
@@ -228,7 +236,7 @@ if __name__ == "__main__":
     writer = SummaryWriter(log_dir='./project/p/models_log/' + args.model + '/full_log')
     # Early-stopping parameters
-    patience = 5
+    patience = 30
     count = 0
     # WARN
     # save_dir = './project/p/ckpt/trail'
...
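
Raising patience from 5 to 30 fits the new scheduler: cosine annealing keeps the learning rate high for much of training, so validation accuracy can plateau for long stretches before improving, and a 5-epoch patience would abort many runs prematurely. The early-stopping loop itself is not shown in this diff; a sketch of the usual patience-counter pattern, consistent with the best_acc and count variables above (the synthetic accuracy values are illustrative only):

best_acc = float("-inf")
patience = 30
count = 0
accuracies = [0.50, 0.62, 0.61, 0.61, 0.70] + [0.69] * 40   # illustrative per-epoch validation accuracy

for epoch, acc in enumerate(accuracies):
    if acc > best_acc:
        best_acc = acc
        count = 0            # improvement: reset the counter (and checkpoint the model here)
    else:
        count += 1           # no improvement this epoch
        if count >= patience:
            print(f"early stopping at epoch {epoch}")
            break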