Commit ab636756 by yuxguo

fix

parent 18237e29
...@@ -370,7 +370,7 @@ def train(trainloader, optimizer, model, epoch): ...@@ -370,7 +370,7 @@ def train(trainloader, optimizer, model, epoch):
counter = 0 counter = 0
for batch_idx, (images, label, meta_target) in enumerate(trainloader): for batch_idx, (images, label, meta_target) in enumerate(trainloader):
counter += 1 counter += 1
if args.use_cuda: if args.use_gpu:
images = images.cuda() images = images.cuda()
label = label.cuda() label = label.cuda()
meta_target = meta_target.cuda() meta_target = meta_target.cuda()
...@@ -401,7 +401,7 @@ def validate(validloader, model, epoch): ...@@ -401,7 +401,7 @@ def validate(validloader, model, epoch):
with torch.no_grad(): with torch.no_grad():
for batch_idx, (images, label, meta_target) in enumerate(validloader): for batch_idx, (images, label, meta_target) in enumerate(validloader):
counter += 1 counter += 1
if args.use_cuda: if args.use_gpu:
images = images.cuda() images = images.cuda()
label = label.cuda() label = label.cuda()
meta_target = meta_target.cuda() meta_target = meta_target.cuda()
...@@ -427,7 +427,7 @@ def test(testloader, model, epoch): ...@@ -427,7 +427,7 @@ def test(testloader, model, epoch):
with torch.no_grad(): with torch.no_grad():
for batch_idx, (images, label, meta_target) in enumerate(testloader): for batch_idx, (images, label, meta_target) in enumerate(testloader):
counter += 1 counter += 1
if args.use_cuda: if args.use_gpu:
images = images.cuda() images = images.cuda()
label = label.cuda() label = label.cuda()
meta_target = meta_target.cuda() meta_target = meta_target.cuda()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment