Commit 0759bda4 by yuxguo

fix

parent 40d3539c
......@@ -380,10 +380,10 @@ def train(trainloader, optimizer, model, epoch):
"label": label
}
loss, logits = model(inputs)
print(loss.shape)
print(logits.shape)
exit()
loss.backward()
# print(loss.shape)
# print(logits.shape)
# exit()
loss.sum().backward()
optimizer.step()
pred = logits.argmax(dim=-1)
......@@ -454,7 +454,7 @@ def main():
model = get_model(args)
if args.use_gpu:
model = DataParallel(model, device_ids=[0, 1]).cuda()
model = DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()
# model = model.cuda()
if args.weight_decay == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment