Commit 84297604 by yuxguo

fix

parent 80ee6423
......@@ -355,9 +355,9 @@ def get_dataloader():
valid = RAVENdataset(args.dataset_path, "val", args.val_figure_configurations, ORIGIN_IMAGE_SIZE, transform=transforms.Compose([ToTensor()]))
test = RAVENdataset(args.dataset_path, "test", args.test_figure_configurations, ORIGIN_IMAGE_SIZE, transform=transforms.Compose([ToTensor()]))
trainloader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=args.load_workers)
validloader = DataLoader(valid, batch_size=args.batch_size, shuffle=False, num_workers=args.load_workers)
testloader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=args.load_workers)
trainloader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
validloader = DataLoader(valid, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
testloader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
return trainloader, validloader, testloader
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment