Commit fb0d02c8 by Werner Duvaud

Add short term memory

parent b1bd9ab3
......@@ -12,10 +12,11 @@ class MuZeroConfig:
self.observation_shape = 4 # Dimensions of the game observation
self.action_space = [i for i in range(2)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
self.stacked_observations = 3 # Number of previous observations to add to the current observation
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.num_actors = 3 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
......
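The new `stacked_observations` option gives the model a short-term memory: the representation network is fed the current frame plus that many previous frames, flattened together. A quick sanity check with the CartPole values above (the variable names are only for this sketch):

```python
observation_shape = 4        # dimensions of one CartPole observation
stacked_observations = 3     # previous frames added to the current one

# models.MuZeroNetwork now builds its representation network with
# observation_size * (stacked_observations + 1) inputs:
representation_input_size = observation_shape * (stacked_observations + 1)
assert representation_input_size == 16
```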
......@@ -12,6 +12,7 @@ class MuZeroConfig:
self.observation_shape = 6 * 7 # Dimensions of the game observation
self.action_space = [i for i in range(7)] # Fixed list of all possible actions
self.players = [i for i in range(2)] # List of players
self.stacked_observations = 2 # Number of previous observations to add to the current observation
### Self-Play
......@@ -95,7 +96,7 @@ class Game:
The new observation, the reward, and a boolean indicating whether the game has ended.
"""
observation, reward, done = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
return observation, reward, done
def to_play(self):
"""
......@@ -126,7 +127,7 @@ class Game:
Returns:
Initial observation of the game.
"""
return numpy.array(self.env.reset()).flatten()
return self.env.reset()
def close(self):
"""
......
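With the observations now stacked, the game wrappers stop flattening: `step` and `reset` return the raw 6 x 7 board, and flattening happens once inside `MuZeroNetwork.representation`. A minimal sketch of the equivalence (names here are illustrative only, not part of the commit):

```python
import numpy
import torch

board = numpy.zeros((6, 7))

# Before: the Game class flattened each observation itself.
flat_old = numpy.array(board).flatten()                # shape (42,)

# After: the raw board is returned, stacked with previous frames, and
# flattened inside the network via observation.view(observation.shape[0], -1).
batch = torch.tensor([board]).float()                  # shape (1, 6, 7)
flat_new = batch.view(batch.shape[0], -1)              # shape (1, 42)

assert flat_new.shape[1] == flat_old.shape[0] == 42
```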
......@@ -12,11 +12,12 @@ class MuZeroConfig:
self.observation_shape = 8 # Dimensions of the game observation
self.action_space = [i for i in range(4)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
self.stacked_observations = 1 # Number of previous observations to add to the current observation
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 1000 # Maximum number of moves if game is not finished before
self.max_moves = 800 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
......@@ -97,7 +98,7 @@ class Game:
The new observation, the reward, and a boolean indicating whether the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward/5, done
return observation, reward, done
def to_play(self):
"""
......
......@@ -2,9 +2,7 @@ import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layer_sizes, output_size, activation=torch.nn.Tanh()
):
def __init__(self, input_size, layer_sizes, output_size, activation=None):
super(FullyConnectedNetwork, self).__init__()
sizes_list = layer_sizes.copy()
sizes_list.insert(0, input_size)
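Changing the `FullyConnectedNetwork` default from `torch.nn.Tanh()` to `None` lets the policy, value and reward heads below drop their explicit `activation=None` arguments, and removes the final Tanh from the representation and dynamics state networks (matching the deleted TODO comments). The constructor body past the lines shown is not in this hunk; the following is only a plausible completion, with the hidden-layer activation and the Sequential container assumed:

```python
import torch

class FullyConnectedNetwork(torch.nn.Module):
    # Sketch only: the hunk shows the new signature; the layer construction
    # below (ReLU hidden activations, a Sequential container) is an assumed
    # reconstruction, not the exact body of the commit.
    def __init__(self, input_size, layer_sizes, output_size, activation=None):
        super(FullyConnectedNetwork, self).__init__()
        sizes_list = layer_sizes.copy()
        sizes_list.insert(0, input_size)
        layers = []
        for i in range(len(sizes_list) - 1):
            layers += [torch.nn.Linear(sizes_list[i], sizes_list[i + 1]),
                       torch.nn.ReLU()]
        layers.append(torch.nn.Linear(sizes_list[-1], output_size))
        if activation is not None:
            # Heads that want a non-linear output must now ask for it;
            # by default the network returns raw logits.
            layers.append(activation)
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```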
......@@ -32,6 +30,7 @@ class MuZeroNetwork(torch.nn.Module):
def __init__(
self,
observation_size,
stacked_observations,
action_space_size,
encoding_size,
hidden_layers,
......@@ -42,28 +41,23 @@ class MuZeroNetwork(torch.nn.Module):
self.full_support_size = 2 * support_size + 1
self.representation_network = FullyConnectedNetwork(
observation_size, [], encoding_size
observation_size * (stacked_observations + 1), [], encoding_size
)
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, hidden_layers, encoding_size
)
# Gradient scaling (See paper appendix Training)
self.dynamics_encoded_state_network.register_backward_hook(
lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size,
hidden_layers,
self.full_support_size,
activation=None,
)
self.prediction_policy_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_size, activation=None
encoding_size, [], self.action_space_size
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], self.full_support_size, activation=None
encoding_size, [], self.full_support_size,
)
def prediction(self, encoded_state):
......@@ -72,16 +66,18 @@ class MuZeroNetwork(torch.nn.Module):
return policy_logit, value
def representation(self, observation):
encoded_state = self.representation_network(observation)
# TODO: Try to remove tanh activation
encoded_state = self.representation_network(
observation.view(observation.shape[0], -1)
)
# Scale encoded state between [0, 1] (See paper appendix Training)
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
encoded_state_diff = encoded_state - encoded_state.min(1, keepdim=True)[0]
encoded_state_normalized = (
encoded_state_diff / encoded_state_diff.max(1, keepdim=True)[0]
)
return encoded_state
def dynamics(self, encoded_state, action):
# Stack encoded_state with one hot action (See paper appendix Network Architecture)
# Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
......@@ -91,20 +87,26 @@ class MuZeroNetwork(torch.nn.Module):
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
# TODO: Try to remove tanh activation
# Scale encoded state between [0, 1] (See paper appendix Training)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
next_encoded_state_diff = (
next_encoded_state - next_encoded_state.min(1, keepdim=True)[0]
)
next_encoded_state_normalized = (
next_encoded_state_diff / next_encoded_state_diff.max(1, keepdim=True)[0]
)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
return next_encoded_state_normalized, reward
def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logit, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation), self.full_support_size).to(observation.device),
torch.zeros(len(observation), self.full_support_size).to(
observation.device
),
policy_logit,
encoded_state,
)
......
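Both `representation` and `dynamics` now rescale the encoded state per sample: the min and max are taken along dim 1 with `keepdim=True`, so each row of the batch is mapped to [0, 1] on its own, instead of one global min/max over the whole batch as before. A small standalone illustration (the function name is only for this sketch):

```python
import torch

def scale_to_unit_interval(encoded_state):
    # Per-sample min-max scaling: each row is mapped to [0, 1] using its own
    # min and max along the feature dimension (dim 1), as in the new code.
    diff = encoded_state - encoded_state.min(1, keepdim=True)[0]
    return diff / diff.max(1, keepdim=True)[0]

batch = torch.tensor([[1.0, 2.0, 3.0],
                      [10.0, 20.0, 30.0]])
print(scale_to_unit_interval(batch))
# tensor([[0.0000, 0.5000, 1.0000],
#         [0.0000, 0.5000, 1.0000]])
# The previous batch-wide min/max would instead squash the first row towards 0,
# because 30.0 dominates the shared denominator.
```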
......@@ -54,10 +54,11 @@ class MuZero:
# Initial weights used to initialize components
self.muzero_weights = models.MuZeroNetwork(
self.config.observation_shape,
self.config.stacked_observations,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_layers,
self.config.support_size
self.config.support_size,
).get_weights()
def train(self):
......@@ -137,7 +138,7 @@ class MuZero:
time.sleep(3)
except KeyboardInterrupt as err:
# Comment the line below to be able to stop the training but keep running
# raise err
raise err
pass
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
......@@ -159,7 +160,9 @@ class MuZero:
)
test_rewards = []
for _ in range(self.config.test_episodes):
history = ray.get(self_play_workers.play_game.remote(0, render, muzero_player))
history = ray.get(
self_play_workers.play_game.remote(0, render, muzero_player)
)
test_rewards.append(sum(history.rewards))
ray.shutdown()
return test_rewards
......@@ -183,7 +186,9 @@ if __name__ == "__main__":
## Test
muzero.load_model()
# Render some self-played games
muzero.test(render=True, muzero_player=None)
# Let user play against MuZero (MuZero is player 0 here)
# muzero.test(render=True, muzero_player=0)
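For reference, the human-play mode added here works through the input prompt introduced in self_play.py below; a usage sketch, assuming the Connect4 config where actions are the column indices 0 to 6:

```python
## Test (sketch: human vs MuZero, MuZero as player 0)
muzero.load_model()
muzero.test(render=True, muzero_player=0)
# On the human's turns, self_play.py now prompts for the move, e.g. the
# Connect4 column index:
#   Enter the action of player 1 : 3
```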
......@@ -22,6 +22,7 @@ class SelfPlay:
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
self.config.stacked_observations,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_layers,
......@@ -66,6 +67,9 @@ class SelfPlay:
"""
game_history = GameHistory()
observation = self.game.reset()
observation = self.stack_previous_observations(
observation, game_history, self.config.stacked_observations
)
game_history.observation_history.append(observation)
done = False
......@@ -75,7 +79,10 @@ class SelfPlay:
with torch.no_grad():
while not done and len(game_history.action_history) < self.config.max_moves:
current_player = self.game.to_play()
if play_against_human_player is None or play_against_human_player == current_player:
if (
play_against_human_player is None
or play_against_human_player == current_player
):
root = MCTS(self.config).run(
self.model,
observation,
......@@ -88,10 +95,19 @@ class SelfPlay:
observation, reward, done = self.game.step(action)
if play_against_human_player is not None and current_player != play_against_human_player:
action = int(input("Enter the action of player {} : ".format(current_player)))
if (
play_against_human_player is not None
and current_player != play_against_human_player
):
action = int(
input("Enter the action of player {} : ".format(current_player))
)
observation, reward, done = self.game.step(action)
observation = self.stack_previous_observations(
observation, game_history, self.config.stacked_observations,
)
if render:
print("Action : {}".format(action))
self.game.render()
......@@ -105,6 +121,20 @@ class SelfPlay:
return game_history
@staticmethod
def stack_previous_observations(
observation, game_history, num_stacked_observations
):
stacked_observations = [observation]
for i in range(num_stacked_observations):
try:
stacked_observations.append(
game_history.observation_history[-(i + 1)][0]
)
except IndexError:
stacked_observations.append(numpy.zeros_like(observation))
return stacked_observations
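A quick usage sketch of the new helper with the CartPole settings (4-dimensional observations, `stacked_observations = 3`); the `_History` stand-in exists only for this example:

```python
import numpy

class _History:
    # Minimal stand-in for GameHistory, only for this sketch.
    observation_history = []

observation = numpy.ones(4)          # current CartPole frame
stacked = SelfPlay.stack_previous_observations(observation, _History(), 3)
# With an empty history the three missing frames are zero-padded:
# [array([1., 1., 1., 1.]), array([0., 0., 0., 0.]),
#  array([0., 0., 0., 0.]), array([0., 0., 0., 0.])]
# Once the game is underway, observation_history[-(i + 1)][0] picks element 0,
# i.e. the raw frame, out of each previously stored stacked observation.
```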
@staticmethod
def select_action(node, temperature):
"""
Select action according to the visit count distribution and the temperature.
......@@ -151,7 +181,7 @@ class MCTS:
"""
root = Node(0)
observation = (
torch.from_numpy(observation)
torch.tensor(observation)
.float()
.unsqueeze(0)
.to(next(model.parameters()).device)
......
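`torch.from_numpy` only accepts a single `numpy.ndarray`, but the observation handed to `MCTS.run` is now a Python list of frames produced by `stack_previous_observations`, hence the switch to `torch.tensor`. A minimal illustration:

```python
import numpy
import torch

stacked = [numpy.ones(4), numpy.zeros(4)]        # list of frames, as in play_game
x = torch.tensor(stacked).float().unsqueeze(0)   # works: shape (1, 2, 4)
# torch.from_numpy(stacked) would raise a TypeError here, because it only
# accepts a single numpy.ndarray, not a list of them.
```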
......@@ -21,6 +21,7 @@ class Trainer:
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
self.config.stacked_observations,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_layers,
......@@ -118,7 +119,7 @@ class Trainer:
loss = (value_loss + reward_loss + policy_loss).mean()
# Scale gradient by number of unroll steps (See paper Training appendix)
loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)
loss.register_hook(lambda grad: grad / self.config.num_unroll_steps)
# Optimize
self.optimizer.zero_grad()
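`loss.register_hook` rescales the gradient flowing out of the scalar loss, implementing the 1/K scaling over the K unroll steps mentioned in the comment; `grad * 1 / K` and `grad / K` compute the same thing, the new form is just cleaner. A tiny check of the effect:

```python
import torch

num_unroll_steps = 10
x = torch.ones(1, requires_grad=True)
loss = (3 * x).sum()
# Scale the gradient leaving the loss by 1 / num_unroll_steps.
loss.register_hook(lambda grad: grad / num_unroll_steps)
loss.backward()
print(x.grad)   # tensor([0.3000]) instead of tensor([3.])
```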
......@@ -160,12 +161,10 @@ class Trainer:
logits.scatter_(
2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
)
indexes = (floor + support_size + 1)
indexes = floor + support_size + 1
prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
logits.scatter_(
2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1)
)
logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
return logits
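The two `scatter_` calls build the categorical ("two-hot") encoding of a scalar target over the support `[-support_size, support_size]`: the mass is split between `floor(x)` and `floor(x) + 1` with weights `1 - prob` and `prob`, and the upper index is masked to 0 when it falls outside the support. A self-contained sketch of the idea on a single scalar (the trainer version operates on batched tensors):

```python
import torch

def scalar_to_support_1d(x, support_size):
    # Two-hot categorical encoding of a scalar on [-support_size, support_size].
    floor = x.floor()
    prob = x - floor                        # mass assigned to floor(x) + 1
    logits = torch.zeros(2 * support_size + 1)
    logits[int(floor) + support_size] = 1 - prob
    index = int(floor) + support_size + 1
    if index <= 2 * support_size:           # drop the upper bin if out of range
        logits[index] = prob
    return logits

print(scalar_to_support_1d(torch.tensor(1.3), support_size=2))
# tensor([0.0000, 0.0000, 0.0000, 0.7000, 0.3000])
# i.e. 0.7 of the mass on value 1 and 0.3 on value 2.
```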
@staticmethod
......@@ -174,6 +173,10 @@ class Trainer:
):
# Cross-entropy had a better convergence than MSE
value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1).mean()
reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
reward_loss = (
(-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
)
policy_loss = (
(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
)
return value_loss, reward_loss, policy_loss
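Since the value, reward and policy targets are full distributions over the support rather than class indices, the cross-entropy is written out as `(-target * log_softmax(pred)).sum(1).mean()` instead of going through `torch.nn.CrossEntropyLoss`, which in the PyTorch releases of that period only accepted hard labels. A tiny numeric check of the formula:

```python
import torch

pred_logits = torch.tensor([[2.0, 0.0, 0.0]])
target = torch.tensor([[1.0, 0.0, 0.0]])     # all probability mass on bin 0
loss = (-target * torch.nn.LogSoftmax(dim=1)(pred_logits)).sum(1).mean()
print(loss)   # tensor(0.2395) = -log softmax(pred)[0] for a one-hot target
```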