Commit b1bd9ab3 by Werner Duvaud

Fix #6

parent 2ac21696
......@@ -24,7 +24,7 @@ class Trainer:
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_layers,
self.config.support_size
self.config.support_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(config.training_device))
......@@ -150,19 +150,21 @@ class Trainer:
See paper appendix Network Architecture
"""
# Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1 + 0.001 * x)
x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x
# Encode on a vector
x = torch.clamp(x, -support_size, support_size)
floor = x.floor()
ceil = x.ceil()
prob = x - floor
logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
logits.scatter_(
2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
)
indexes = (floor + support_size + 1)
prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
logits.scatter_(
2, (ceil + support_size).long().unsqueeze(-1), prob.unsqueeze(-1)
2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1)
)
return logits
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment