Commit 2d9fa316 by Werner Duvaud

Add short term memory

parent fb0d02c8
......@@ -74,7 +74,7 @@ class MuZeroNetwork(torch.nn.Module):
encoded_state_normalized = (
encoded_state_diff / encoded_state_diff.max(1, keepdim=True)[0]
)
-return encoded_state
+return encoded_state_normalized
def dynamics(self, encoded_state, action):
# Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment