haoyifan / Model-Transfer-Adaptability

Commit e230b520, authored Apr 07, 2023 by Zhihong Ma
fix: modify optimizer, scheduler, data augmentation for higher acc
parent 12b90b92

Showing 1 changed file with 11 additions and 3 deletions (+11, -3).

mzh/new_train.py
@@ -6,7 +6,7 @@ from get_weight import *
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import CosineAnnealingLR
from resnet import *
from torchvision.transforms import transforms
# import models
@@ -135,7 +135,9 @@ if __name__ == "__main__":
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE',
                        help='mini-batch size (default: 128)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='WORKERS',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('-wd', '--weight_decay', default=0.0001, type=float, metavar='WD',
                        help='lr scheduler weight decay', dest='wd')
    parser.add_argument('-t', '--test', dest='test', action='store_true',
                        help='test model on test set')
    # models = ['resnet18', 'resnet50', 'resnet152','resnet18']
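Since the new -wd/--weight_decay flag uses an explicit dest, a quick way to see how these flags map to attribute names is a minimal argparse sketch; the parse_args argument list below is illustrative, not taken from the repository:

import argparse

# Minimal reproduction of two arguments from the hunk above, showing how
# dest controls the attribute name on the parsed namespace.
parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('-wd', '--weight_decay', default=0.0001, type=float,
                    metavar='WD', help='lr scheduler weight decay', dest='wd')

args = parser.parse_args(['-lr', '0.01', '-wd', '0.0005'])
print(args.lr, args.wd)   # -> 0.01 0.0005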
@@ -150,6 +152,7 @@ if __name__ == "__main__":
    print(batch_size)
    num_workers = args.workers
    lr = args.lr
    weight_decay = args.wd
    best_acc = float("-inf")
@@ -183,6 +186,8 @@ if __name__ == "__main__":
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # optimizer = optim.AdaBound(model.parameters(), lr=lr,
    #                            weight_decay=weight_decay, final_lr=0.001*lr)
    print("ok!")
    # data parallelism
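Note that args.wd is parsed and assigned to weight_decay earlier, but in the lines shown it only appears in the commented-out AdaBound call; the active Adam optimizer is created without it. AdaBound is also not part of stock torch.optim, it typically comes from the third-party adabound package, which may be why that line is commented out. A hedged sketch of what passing the decay through to Adam would look like (the model here is a stand-in, and applying it is an assumption about intent, not the repository's code):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 10)        # stand-in model for illustration
lr, weight_decay = 0.001, 0.0001

# Adam as constructed in the diff (no weight decay applied):
optimizer = optim.Adam(model.parameters(), lr=lr)

# Adam also accepts the decay directly, if it were meant to reach the optimizer:
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)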
@@ -199,6 +204,8 @@ if __name__ == "__main__":
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=True, download=False,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=2),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                  (0.2023, 0.1994, 0.2010))
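The hunk above is cut off by the diff view after Normalize. As an illustration only (not a reconstruction of the file's remaining lines), the augmentation chain it shows can be exercised end to end on a dummy image:

import numpy as np
from PIL import Image
from torchvision import transforms

# The transform chain from the hunk above, with the CIFAR-10 mean/std
# values that appear in the diff.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

# Dummy 32x32 RGB image standing in for a CIFAR-10 sample.
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
out = train_transform(img)
print(out.shape)   # torch.Size([3, 32, 32]), roughly zero-mean after Normalize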
@@ -219,7 +226,8 @@ if __name__ == "__main__":
    # test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # learning-rate scheduler
-   lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
+   # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
+   lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    # TensorBoard
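The commit's headline scheduler change replaces step decay with cosine annealing, which lowers the learning rate smoothly from its initial value toward zero over T_max epochs instead of dropping it by 10x every 5 epochs. A minimal sketch of the resulting schedule; model, optimizer, and num_epochs are stand-ins:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

model = nn.Linear(4, 2)                       # stand-in model
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10                               # assumed; the diff uses T_max=num_epochs
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # ... one epoch of training would run here ...
    optimizer.step()                          # step the optimizer before the scheduler
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])  # half-cosine decay from 1e-3 toward 0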
@@ -228,7 +236,7 @@ if __name__ == "__main__":
    writer = SummaryWriter(log_dir='./project/p/models_log/' + args.model + '/full_log')
    # Early Stopping parameters
-   patience = 5
+   patience = 30
    count = 0
    # WARN
    # save_dir = './project/p/ckpt/trail'
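The patience threshold jumps from 5 to 30 epochs, making early stopping much more tolerant; plausibly this accommodates the slower, smoother decay of the cosine schedule, under which accuracy can improve late in training. A self-contained sketch of the counter pattern these variables imply, with the validation accuracy simulated since the evaluation loop is not shown in the diff:

import random

patience = 30
count = 0
best_acc = float("-inf")
num_epochs = 200

for epoch in range(num_epochs):
    val_acc = random.random()   # stand-in for a real validation pass
    if val_acc > best_acc:
        best_acc = val_acc
        count = 0               # reset the counter on any improvement
    else:
        count += 1
        if count >= patience:   # stop after `patience` epochs without improvement
            print(f"early stop at epoch {epoch}, best_acc={best_acc:.3f}")
            break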