Commit 173a1f85 by yuxguo

fix

parent 0759bda4
......@@ -383,7 +383,8 @@ def train(trainloader, optimizer, model, epoch):
# print(loss.shape)
# print(logits.shape)
# exit()
loss.sum().backward()
loss = loss.sum()
loss.backward()
optimizer.step()
pred = logits.argmax(dim=-1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment