Commit 67498b6b by yuxguo

fix

parent a301bb7c
......@@ -381,8 +381,10 @@ def train(trainloader, optimizer, model, epoch):
}
loss, logits = model(inputs)
print(loss.shape)
loss.backward()
print(logits.shape)
exit()
loss.backward()
optimizer.step()
pred = logits.argmax(dim=-1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment