Commit b1bd9ab3 by Werner Duvaud

Fix #6

parent 2ac21696
@@ -24,7 +24,7 @@ class Trainer:
             len(self.config.action_space),
             self.config.encoding_size,
             self.config.hidden_layers,
-            self.config.support_size
+            self.config.support_size,
         )
         self.model.set_weights(initial_weights)
         self.model.to(torch.device(config.training_device))
@@ -150,19 +150,21 @@ class Trainer:
         See paper appendix Network Architecture
         """
         # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
-        x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1 + 0.001 * x)
+        x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x
         # Encode on a vector
         x = torch.clamp(x, -support_size, support_size)
         floor = x.floor()
-        ceil = x.ceil()
         prob = x - floor
         logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
         logits.scatter_(
             2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
         )
+        indexes = (floor + support_size + 1)
+        prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
+        indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
         logits.scatter_(
-            2, (ceil + support_size).long().unsqueeze(-1), prob.unsqueeze(-1)
+            2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1)
         )
         return logits
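
For context on the second hunk: moving the 0.001 * x term outside torch.sign(x) makes the scaling match h(x) = sign(x) * (sqrt(|x| + 1) - 1) + 0.001 * x from the cited paper, and replacing the ceil-based scatter (which overwrote the floor bin with zero whenever x fell exactly on an integer, including values clamped to +support_size) with a masked floor + 1 index keeps every write inside the 2 * support_size + 1 bins. Below is a minimal standalone sketch of the corrected transform; the function name scalar_to_support and the 2-D (batch, sequence) input shape are assumptions for illustration, not text taken from the commit.

# Illustrative sketch, not the committed file: reproduces the fixed transform
# for a 2-D tensor of scalar targets.
import torch


def scalar_to_support(x, support_size):
    # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x
    # Encode each scalar on a vector of 2 * support_size + 1 bins
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob = x - floor
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
    # Put weight (1 - prob) on the floor bin...
    logits.scatter_(
        2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
    )
    # ...and prob on the next bin, zeroing it where floor + 1 would fall
    # outside the support (values clamped to +support_size).
    indexes = floor + support_size + 1
    prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
    logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
    return logits


# Quick check: a target clamped to +support_size keeps a valid one-hot row instead
# of being overwritten to zeros or indexing out of range, and each row sums to 1.
out = scalar_to_support(torch.tensor([[1e6, 0.5]]), support_size=10)
assert out.shape == (1, 2, 21)
assert torch.allclose(out.sum(-1), torch.ones(1, 2))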