Commit 0759bda4 by yuxguo

fix

parent 40d3539c
...@@ -380,10 +380,10 @@ def train(trainloader, optimizer, model, epoch): ...@@ -380,10 +380,10 @@ def train(trainloader, optimizer, model, epoch):
"label": label "label": label
} }
loss, logits = model(inputs) loss, logits = model(inputs)
print(loss.shape) # print(loss.shape)
print(logits.shape) # print(logits.shape)
exit() # exit()
loss.backward() loss.sum().backward()
optimizer.step() optimizer.step()
pred = logits.argmax(dim=-1) pred = logits.argmax(dim=-1)
...@@ -454,7 +454,7 @@ def main(): ...@@ -454,7 +454,7 @@ def main():
model = get_model(args) model = get_model(args)
if args.use_gpu: if args.use_gpu:
model = DataParallel(model, device_ids=[0, 1]).cuda() model = DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()
# model = model.cuda() # model = model.cuda()
if args.weight_decay == 0: if args.weight_decay == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment