Commit 67498b6b by yuxguo

fix

parent a301bb7c
...@@ -381,8 +381,10 @@ def train(trainloader, optimizer, model, epoch): ...@@ -381,8 +381,10 @@ def train(trainloader, optimizer, model, epoch):
} }
loss, logits = model(inputs) loss, logits = model(inputs)
print(loss.shape) print(loss.shape)
loss.backward() print(logits.shape)
exit() exit()
loss.backward()
optimizer.step() optimizer.step()
pred = logits.argmax(dim=-1) pred = logits.argmax(dim=-1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment