Commit 9eed4263 by yuxguo

fix

parent ab636756
@@ -102,7 +102,7 @@ trainer_args.add_argument('--epochs', '-e', type=int, default=200,
                           help='the number of epochs')
 trainer_args.add_argument('--obs-epochs', '-oe', type=int, default=5,
                           help='the number of sub epochs for observation stage')
-trainer_args.add_argument('--batch-size', '-bs', type=int, default=1,
+trainer_args.add_argument('--batch-size', '-bs', type=int, default=32,
                           help='input batch size for training (default: 1)')
 trainer_args.add_argument('--eval-batch-size', '-ebs', type=int, default=1,
                           help='input batch size for evaluation (default: 32)')
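For reference, a minimal sketch of how the changed `--batch-size` flag behaves after this commit. Only the `add_argument` calls appear in the diff, so the enclosing parser and argument-group setup below are assumptions. Note also that the committed help string still reads `(default: 1)` even though the default is now 32; the sketch updates it for consistency:

```python
import argparse

# Hypothetical minimal reconstruction: the diff only shows the argument
# definitions, so this parser/group scaffolding is assumed.
parser = argparse.ArgumentParser()
trainer_args = parser.add_argument_group('trainer')
trainer_args.add_argument('--batch-size', '-bs', type=int, default=32,
                          help='input batch size for training (default: 32)')

args = parser.parse_args([])             # no flag given: new default applies
assert args.batch_size == 32
args = parser.parse_args(['-bs', '8'])   # an explicit flag still overrides it
assert args.batch_size == 8
```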
@@ -450,8 +450,8 @@ def main():
     model = get_model(args)
     if args.use_gpu:
-        # model = DataParallel(model).cuda()
-        model = model.cuda()
+        model = DataParallel(model).cuda()
+        # model = model.cuda()
     if args.weight_decay == 0:
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
...
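The second hunk switches single-GPU placement (`model.cuda()`) for multi-GPU data parallelism by re-enabling the previously commented-out `DataParallel` wrapper. A minimal sketch of the pattern, with a stand-in model since `get_model(args)` is not shown in the diff:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model; the real code builds it via get_model(args), not shown here.
model = nn.Linear(128, 10)

if torch.cuda.is_available():
    # nn.DataParallel replicates the module on every visible GPU, splits each
    # input batch along dim 0, and gathers the outputs on the default device.
    model = nn.DataParallel(model).cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-3)
```

Since `DataParallel` splits each batch across the available GPUs, a per-step batch of 1 would leave all but one GPU idle; raising the training default from 1 to 32 in the first hunk is consistent with this switch.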