Commit 18237e29 by yuxguo

fix

parent 6188a6bb
...@@ -450,7 +450,8 @@ def main(): ...@@ -450,7 +450,8 @@ def main():
model = get_model(args) model = get_model(args)
if args.use_gpu: if args.use_gpu:
model = DataParallel(model).cuda() # model = DataParallel(model).cuda()
model = model.cuda()
if args.weight_decay == 0: if args.weight_decay == 0:
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment