Commit fb0d02c8 by Werner Duvaud

Add short term memory

parent b1bd9ab3
@@ -12,10 +12,11 @@ class MuZeroConfig:
 self.observation_shape = 4 # Dimensions of the game observation
 self.action_space = [i for i in range(2)] # Fixed list of all possible actions
 self.players = [i for i in range(1)] # List of players
+self.stacked_observations = 3 # Number of previous observation to add to the current observation
 ### Self-Play
-self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
+self.num_actors = 3 # Number of simultaneous threads self-playing to feed the replay buffer
 self.max_moves = 500 # Maximum number of moves if game is not finished before
 self.num_simulations = 50 # Number of futur moves self-simulated
 self.discount = 0.997 # Chronological discount of the reward
...
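With the new `stacked_observations` option, the model no longer sees a single frame but the current observation concatenated with the most recent ones, and the representation network's input grows to `observation_size * (stacked_observations + 1)` (see the `MuZeroNetwork` hunk further down). A minimal sketch for the first configuration above (names below are illustrative, not part of the diff):

```python
# Sketch: how stacked_observations changes the representation input size
# for the config above (observation_shape = 4, stacked_observations = 3).
observation_shape = 4
stacked_observations = 3

# The model flattens (stacked_observations + 1) observations into one vector,
# matching observation_size * (stacked_observations + 1) in the network hunk below.
representation_input_size = observation_shape * (stacked_observations + 1)
print(representation_input_size)  # 16
```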
@@ -12,6 +12,7 @@ class MuZeroConfig:
 self.observation_shape = 6 * 7 # Dimensions of the game observation
 self.action_space = [i for i in range(7)] # Fixed list of all possible actions
 self.players = [i for i in range(2)] # List of players
+self.stacked_observations = 2 # Number of previous observation to add to the current observation
 ### Self-Play
@@ -95,7 +96,7 @@ class Game:
 The new observation, the reward and a boolean if the game has ended.
 """
 observation, reward, done = self.env.step(action)
-return numpy.array(observation).flatten(), reward, done
+return observation, reward, done
 def to_play(self):
 """
@@ -126,7 +127,7 @@ class Game:
 Returns:
 Initial observation of the game.
 """
-return numpy.array(self.env.reset()).flatten()
+return self.env.reset()
 def close(self):
 """
...
@@ -12,11 +12,12 @@ class MuZeroConfig:
 self.observation_shape = 8 # Dimensions of the game observation
 self.action_space = [i for i in range(4)] # Fixed list of all possible actions
 self.players = [i for i in range(1)] # List of players
+self.stacked_observations = 1 # Number of previous observation to add to the current observation
 ### Self-Play
 self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
-self.max_moves = 1000 # Maximum number of moves if game is not finished before
+self.max_moves = 800 # Maximum number of moves if game is not finished before
 self.num_simulations = 50 # Number of futur moves self-simulated
 self.discount = 0.997 # Chronological discount of the reward
 self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
@@ -97,7 +98,7 @@ class Game:
 The new observation, the reward and a boolean if the game has ended.
 """
 observation, reward, done, _ = self.env.step(action)
-return numpy.array(observation).flatten(), reward/5, done
+return observation, reward, done
 def to_play(self):
 """
...
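The game wrappers' `step()` and `reset()` now return the raw environment observation: the `numpy.array(...).flatten()` calls (and, in this last file, the `reward/5` scaling) are dropped, and flattening happens once per sample inside `MuZeroNetwork.representation()` via `observation.view(observation.shape[0], -1)` (see the network hunks below). A small sketch of the equivalence, with illustrative shapes only:

```python
import numpy
import torch

# Sketch: flattening now happens inside the network rather than in Game.step().
obs = numpy.arange(6 * 7).reshape(6, 7)          # e.g. a 6 x 7 board, as in one config above

# Old behaviour in the game wrapper:
flat_in_game = numpy.array(obs).flatten()        # shape (42,)

# New behaviour: raw observations are batched, then flattened per sample in the model.
batch = torch.tensor(numpy.stack([obs, obs])).float()  # shape (2, 6, 7)
flat_in_model = batch.view(batch.shape[0], -1)         # shape (2, 42)

assert flat_in_model.shape == (2, 42)
```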
@@ -2,9 +2,7 @@ import torch
 class FullyConnectedNetwork(torch.nn.Module):
-def __init__(
-self, input_size, layer_sizes, output_size, activation=torch.nn.Tanh()
-):
+def __init__(self, input_size, layer_sizes, output_size, activation=None):
 super(FullyConnectedNetwork, self).__init__()
 sizes_list = layer_sizes.copy()
 sizes_list.insert(0, input_size)
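Only the constructor's opening lines appear in this hunk; the rest of the class is unchanged. As a hypothetical sketch of the pattern (the layer wiring and the ReLU between hidden layers are assumptions, not taken from the diff), such a block builds `Linear` layers from `[input_size] + layer_sizes + [output_size]` and applies the optional `activation` on the output, which now defaults to `None` (raw logits) instead of `torch.nn.Tanh()`:

```python
import torch

class SketchFullyConnectedNetwork(torch.nn.Module):
    """Hypothetical stand-in illustrating the constructor shown in the diff."""

    def __init__(self, input_size, layer_sizes, output_size, activation=None):
        super().__init__()
        sizes_list = layer_sizes.copy()
        sizes_list.insert(0, input_size)
        layers = []
        for i in range(len(sizes_list) - 1):
            # Hidden activation chosen here for illustration only.
            layers.extend([torch.nn.Linear(sizes_list[i], sizes_list[i + 1]), torch.nn.ReLU()])
        layers.append(torch.nn.Linear(sizes_list[-1], output_size))
        if activation is not None:
            layers.append(activation)  # e.g. torch.nn.Tanh(); omitted by the new default
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```

With the new default, `FullyConnectedNetwork(encoding_size, [], action_space_size)` returns unnormalized logits, which is why the explicit `activation=None` arguments disappear in the hunks below.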
@@ -32,6 +30,7 @@ class MuZeroNetwork(torch.nn.Module):
 def __init__(
 self,
 observation_size,
+stacked_observations,
 action_space_size,
 encoding_size,
 hidden_layers,
@@ -42,28 +41,23 @@ class MuZeroNetwork(torch.nn.Module):
 self.full_support_size = 2 * support_size + 1
 self.representation_network = FullyConnectedNetwork(
-observation_size, [], encoding_size
+observation_size * (stacked_observations + 1), [], encoding_size
 )
 self.dynamics_encoded_state_network = FullyConnectedNetwork(
 encoding_size + self.action_space_size, hidden_layers, encoding_size
 )
-# Gradient scaling (See paper appendix Training)
-self.dynamics_encoded_state_network.register_backward_hook(
-lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
-)
 self.dynamics_reward_network = FullyConnectedNetwork(
 encoding_size + self.action_space_size,
 hidden_layers,
 self.full_support_size,
-activation=None,
 )
 self.prediction_policy_network = FullyConnectedNetwork(
-encoding_size, [], self.action_space_size, activation=None
+encoding_size, [], self.action_space_size
 )
 self.prediction_value_network = FullyConnectedNetwork(
-encoding_size, [], self.full_support_size, activation=None
+encoding_size, [], self.full_support_size,
 )
 def prediction(self, encoded_state):
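This hunk also removes the module-level `register_backward_hook` that halved the gradient flowing through the dynamics network (the gradient-scaling trick from the MuZero paper's Training appendix). If that scaling is wanted without a module hook, a common alternative is to scale the gradient at the tensor level during the unroll; this is only a sketch, not this repository's code:

```python
import torch

def scale_gradient(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Sketch: forward value is unchanged, backward gradient is multiplied by `scale`.

    Equivalent in spirit to the removed register_backward_hook(... grad_i[0] * 0.5 ...).
    """
    return tensor * scale + tensor.detach() * (1 - scale)

# Hypothetical usage inside an unroll loop:
# hidden_state, reward = model.dynamics(hidden_state, action)
# hidden_state = scale_gradient(hidden_state, 0.5)
```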
@@ -72,16 +66,18 @@ class MuZeroNetwork(torch.nn.Module):
 return policy_logit, value
 def representation(self, observation):
-encoded_state = self.representation_network(observation)
-# TODO: Try to remove tanh activation
+encoded_state = self.representation_network(
+observation.view(observation.shape[0], -1)
+)
 # Scale encoded state between [0, 1] (See appendix paper Training)
-encoded_state = (encoded_state - torch.min(encoded_state)) / (
-torch.max(encoded_state) - torch.min(encoded_state)
-)
+encoded_state_diff = encoded_state - encoded_state.min(1, keepdim=True)[0]
+encoded_state_normalized = (
+encoded_state_diff / encoded_state_diff.max(1, keepdim=True)[0]
+)
 return encoded_state
 def dynamics(self, encoded_state, action):
-# Stack encoded_state with one hot action (See paper appendix Network Architecture)
+# Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
 action_one_hot = (
 torch.zeros((action.shape[0], self.action_space_size))
 .to(action.device)
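The encoded state is still rescaled to [0, 1] as in the paper's Training appendix, but the min and max are now taken per sample (`dim=1`) instead of over the whole batch, so one outlier state no longer shifts the scaling of the others. A small sketch of the difference, with illustrative tensors only:

```python
import torch

encoded_state = torch.tensor([[0.0, 1.0, 2.0],
                              [10.0, 20.0, 30.0]])

# Old: a single min/max over the whole batch.
batch_norm = (encoded_state - encoded_state.min()) / (
    encoded_state.max() - encoded_state.min()
)
# tensor([[0.0000, 0.0333, 0.0667],
#         [0.3333, 0.6667, 1.0000]])

# New: min/max per row, so every sample spans the full [0, 1] range.
diff = encoded_state - encoded_state.min(1, keepdim=True)[0]
row_norm = diff / diff.max(1, keepdim=True)[0]
# tensor([[0.0, 0.5, 1.0],
#         [0.0, 0.5, 1.0]])
```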
@@ -91,20 +87,26 @@ class MuZeroNetwork(torch.nn.Module):
 x = torch.cat((encoded_state, action_one_hot), dim=1)
 next_encoded_state = self.dynamics_encoded_state_network(x)
-# TODO: Try to remove tanh activation
 # Scale encoded state between [0, 1] (See paper appendix Training)
-next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
-torch.max(next_encoded_state) - torch.min(next_encoded_state)
-)
+next_encoded_state_diff = (
+next_encoded_state - next_encoded_state.min(1, keepdim=True)[0]
+)
+next_encoded_state_normalized = (
+next_encoded_state_diff / next_encoded_state_diff.max(1, keepdim=True)[0]
+)
 reward = self.dynamics_reward_network(x)
-return next_encoded_state, reward
+return next_encoded_state_normalized, reward
 def initial_inference(self, observation):
 encoded_state = self.representation(observation)
 policy_logit, value = self.prediction(encoded_state)
 return (
 value,
-torch.zeros(len(observation), self.full_support_size).to(observation.device),
+torch.zeros(len(observation), self.full_support_size).to(
+observation.device
+),
 policy_logit,
 encoded_state,
 )
...
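For the dynamics function, the action is one-hot encoded and concatenated to the encoded state before being fed to the dynamics networks, as the updated comment ("a game specific one hot encoded action", paper appendix Network Architecture) describes. A compact sketch of that encoding step, with made-up sizes:

```python
import torch

action_space_size = 4
encoded_state = torch.rand(2, 8)   # batch of 2 states, encoding_size = 8
action = torch.tensor([[1], [3]])  # one action index per sample

# One-hot encode the actions, then concatenate along the feature dimension.
action_one_hot = torch.zeros(action.shape[0], action_space_size).scatter_(
    1, action.long(), 1.0
)
x = torch.cat((encoded_state, action_one_hot), dim=1)  # shape (2, 8 + 4)
assert x.shape == (2, 12)
```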
@@ -54,10 +54,11 @@ class MuZero:
 # Initial weights used to initialize components
 self.muzero_weights = models.MuZeroNetwork(
 self.config.observation_shape,
+self.config.stacked_observations,
 len(self.config.action_space),
 self.config.encoding_size,
 self.config.hidden_layers,
-self.config.support_size
+self.config.support_size,
 ).get_weights()
 def train(self):
@@ -137,7 +138,7 @@ class MuZero:
 time.sleep(3)
 except KeyboardInterrupt as err:
 # Comment the line below to be able to stop the training but keep running
-# raise err
+raise err
 pass
 self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
 ray.shutdown()
@@ -159,7 +160,9 @@ class MuZero:
 )
 test_rewards = []
 for _ in range(self.config.test_episodes):
-history = ray.get(self_play_workers.play_game.remote(0, render, muzero_player))
+history = ray.get(
+self_play_workers.play_game.remote(0, render, muzero_player)
+)
 test_rewards.append(sum(history.rewards))
 ray.shutdown()
 return test_rewards
@@ -183,7 +186,9 @@ if __name__ == "__main__":
 ## Test
 muzero.load_model()
 # Render some self-played games
 muzero.test(render=True, muzero_player=None)
 # Let user play against MuZero (MuZero is player 0 here)
 # muzero.test(render=True, muzero_player=0)
@@ -22,6 +22,7 @@ class SelfPlay:
 # Initialize the network
 self.model = models.MuZeroNetwork(
 self.config.observation_shape,
+self.config.stacked_observations,
 len(self.config.action_space),
 self.config.encoding_size,
 self.config.hidden_layers,
@@ -66,6 +67,9 @@ class SelfPlay:
 """
 game_history = GameHistory()
 observation = self.game.reset()
+observation = self.stack_previous_observations(
+observation, game_history, self.config.stacked_observations
+)
 game_history.observation_history.append(observation)
 done = False
@@ -75,7 +79,10 @@ class SelfPlay:
 with torch.no_grad():
 while not done and len(game_history.action_history) < self.config.max_moves:
 current_player = self.game.to_play()
-if play_against_human_player is None or play_against_human_player == current_player:
+if (
+play_against_human_player is None
+or play_against_human_player == current_player
+):
 root = MCTS(self.config).run(
 self.model,
 observation,
@@ -88,10 +95,19 @@ class SelfPlay:
 observation, reward, done = self.game.step(action)
-if play_against_human_player is not None and current_player != play_against_human_player:
-action = int(input("Enter the action of player {} : ".format(current_player)))
+if (
+play_against_human_player is not None
+and current_player != play_against_human_player
+):
+action = int(
+input("Enter the action of player {} : ".format(current_player))
+)
 observation, reward, done = self.game.step(action)
+observation = self.stack_previous_observations(
+observation, game_history, self.config.stacked_observations,
+)
 if render:
 print("Action : {}".format(action))
 self.game.render()
@@ -105,6 +121,20 @@ class SelfPlay:
 return game_history
 @staticmethod
+def stack_previous_observations(
+observation, game_history, num_stacked_observations
+):
+stacked_observations = [observation]
+for i in range(num_stacked_observations):
+try:
+stacked_observations.append(
+game_history.observation_history[-(i + 1)][0]
+)
+except IndexError:
+stacked_observations.append(numpy.zeros_like(observation))
+return stacked_observations
+@staticmethod
 def select_action(node, temperature):
 """
 Select action according to the vivist count distribution and the temperature.
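The new `stack_previous_observations` helper is the "short term memory" of the commit title: it returns a list containing the current observation followed by the current frame (index `[0]`) of the last N entries of `game_history.observation_history`, zero-padding while the history is still shorter than N. A quick illustrative run with a hypothetical minimal history object:

```python
import numpy

class FakeGameHistory:
    """Hypothetical stand-in for GameHistory, only for this illustration."""
    def __init__(self):
        self.observation_history = []

def stack_previous_observations(observation, game_history, num_stacked_observations):
    # Same logic as the static method added in the hunk above.
    stacked_observations = [observation]
    for i in range(num_stacked_observations):
        try:
            stacked_observations.append(game_history.observation_history[-(i + 1)][0])
        except IndexError:
            stacked_observations.append(numpy.zeros_like(observation))
    return stacked_observations

history = FakeGameHistory()
obs_0 = numpy.array([1.0, 2.0])

# First call: no history yet, the two previous slots are zero-padded.
stacked_0 = stack_previous_observations(obs_0, history, 2)
# [array([1., 2.]), array([0., 0.]), array([0., 0.])]
history.observation_history.append(stacked_0)

# Next call: the first previous slot now holds the last current observation.
obs_1 = numpy.array([3.0, 4.0])
stacked_1 = stack_previous_observations(obs_1, history, 2)
# [array([3., 4.]), array([1., 2.]), array([0., 0.])]
```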
@@ -151,7 +181,7 @@ class MCTS:
 """
 root = Node(0)
 observation = (
-torch.from_numpy(observation)
+torch.tensor(observation)
 .float()
 .unsqueeze(0)
 .to(next(model.parameters()).device)
...
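`torch.from_numpy` only accepts a single NumPy array, while the observation handed to MCTS is now a Python list of stacked arrays, which `torch.tensor` copies into one tensor. A minimal illustration with made-up shapes:

```python
import numpy
import torch

stacked = [numpy.ones(4), numpy.zeros(4), numpy.zeros(4), numpy.zeros(4)]

# torch.from_numpy(stacked) would raise TypeError: it needs a single ndarray.
observation = (
    torch.tensor(stacked)  # copies the list of arrays into a (4, 4) tensor
    .float()
    .unsqueeze(0)          # add a batch dimension -> (1, 4, 4)
)
assert observation.shape == (1, 4, 4)
```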
@@ -21,6 +21,7 @@ class Trainer:
 # Initialize the network
 self.model = models.MuZeroNetwork(
 self.config.observation_shape,
+self.config.stacked_observations,
 len(self.config.action_space),
 self.config.encoding_size,
 self.config.hidden_layers,
@@ -118,7 +119,7 @@ class Trainer:
 loss = (value_loss + reward_loss + policy_loss).mean()
 # Scale gradient by number of unroll steps (See paper Training appendix)
-loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)
+loss.register_hook(lambda grad: grad / self.config.num_unroll_steps)
 # Optimize
 self.optimizer.zero_grad()
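`Tensor.register_hook` multiplies the incoming gradient during backpropagation, so dividing by `num_unroll_steps` here averages the unrolled loss terms without changing the forward value of `loss`; the change from `grad * 1 / n` to `grad / n` is purely cosmetic. A toy check of the mechanism (variable names are illustrative):

```python
import torch

num_unroll_steps = 5
w = torch.tensor(2.0, requires_grad=True)
loss = 3.0 * w                     # d(loss)/dw = 3

# Scale the gradient flowing out of `loss` by 1 / num_unroll_steps.
loss.register_hook(lambda grad: grad / num_unroll_steps)
loss.backward()

print(w.grad)                      # tensor(0.6000) == 3 / 5
```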
@@ -160,20 +161,22 @@ class Trainer:
 logits.scatter_(
 2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
 )
-indexes = (floor + support_size + 1)
+indexes = floor + support_size + 1
 prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
 indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
-logits.scatter_(
-2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1)
-)
+logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
 return logits
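This helper, only partly visible in the hunk, writes a scalar target onto the categorical support of size `2 * support_size + 1`: the mass `1 - prob` goes to bin `floor(x) + support_size` and `prob` (presumably the fractional part `x - floor(x)`) to the bin just above, with the overflow bin masked out. A self-contained sketch of that two-hot encoding for a batch of scalars; this is not the repository's exact function, which operates on an extra unroll dimension (hence `dim=2` above):

```python
import torch

def scalar_to_support_sketch(x: torch.Tensor, support_size: int) -> torch.Tensor:
    """Two-hot encode scalars of shape (batch,) onto a support of 2*support_size+1 bins."""
    floor = x.floor()
    prob = x - floor                                   # fractional part
    logits = torch.zeros(x.shape[0], 2 * support_size + 1)
    logits.scatter_(1, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1))
    indexes = floor + support_size + 1
    prob = prob.masked_fill(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill(2 * support_size < indexes, 0.0)
    logits.scatter_(1, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
    return logits

print(scalar_to_support_sketch(torch.tensor([0.7]), support_size=2))
# tensor([[0.0000, 0.0000, 0.3000, 0.7000, 0.0000]])
```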
 @staticmethod
 def loss_function(
 value, reward, policy_logits, target_value, target_reward, target_policy
 ):
 # Cross-entropy had a better convergence than MSE
 value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1).mean()
-reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
-policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
+reward_loss = (
+(-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
+)
+policy_loss = (
+(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
+)
 return value_loss, reward_loss, policy_loss
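Value, reward, and policy targets are distributions (over the support or over actions), so each loss is a cross-entropy between the target distribution and the predicted logits: `(-target * log_softmax(logits)).sum(1).mean()`. A small numeric check of that expression, with illustrative values only:

```python
import torch

# One sample, a 5-bin support: target is the two-hot encoding of 0.7 from the sketch above.
target = torch.tensor([[0.0, 0.0, 0.3, 0.7, 0.0]])
logits = torch.tensor([[0.0, 0.0, 1.0, 2.0, 0.0]])

loss = (-target * torch.nn.LogSoftmax(dim=1)(logits)).sum(1).mean()
print(loss)  # ~0.873 for these numbers

# In recent PyTorch versions this matches cross_entropy with soft (probability) targets:
# torch.nn.functional.cross_entropy(logits, target)
```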