Commit a301bb7c by yuxguo

fix

parent 9eed4263
......@@ -380,7 +380,9 @@ def train(trainloader, optimizer, model, epoch):
"label": label
}
loss, logits = model(inputs)
print(loss.shape)
loss.backward()
exit()
optimizer.step()
pred = logits.argmax(dim=-1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment