Commit c5f6bffc by Werner Duvaud

Add gradient scaling and update hyperparameters

parent ceb98db4
@@ -6,16 +6,18 @@ class MuZeroConfig:
     def __init__(self):
         self.seed = 0  # Seed for numpy, torch and the game

         ### Game
         self.observation_shape = 4  # Dimensions of the game observation
-        self.action_space = [i for i in range(2)]  # Fixed list of all possible actions
+        self.action_space = [i for i in range(2)]  # Fixed list of all possible actions (float between 0 and 1)

         ### Self-Play
         self.num_actors = 10  # Number of simultaneous threads self-playing to feed the replay buffer
         self.max_moves = 500  # Maximum number of moves if game is not finished before
-        self.num_simulations = 50  # Number of future moves self-simulated
+        self.num_simulations = 80  # Number of future moves self-simulated
         self.discount = 0.997  # Chronological discount of the reward
-        self.self_play_delay = None  # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid overfitting (Recommended is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
+        self.self_play_delay = 0  # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting

         # Root prior exploration noise
         self.root_dirichlet_alpha = 0.25
@@ -25,30 +27,33 @@ class MuZeroConfig:
         self.pb_c_base = 19652
         self.pb_c_init = 1.25

         ### Network
-        self.encoding_size = 64
-        self.hidden_size = 32
+        self.encoding_size = 32
+        self.hidden_size = 64

-        # Training
+        ### Training
         self.results_path = "./pretrained"  # Path to store the model weights
-        self.training_steps = 2000  # Total number of training steps (ie weights update according to a batch)
+        self.training_steps = 10000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
         self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
-        self.training_delay = 1  # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid overfitting (Recommended is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
+        self.training_delay = 0  # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid over/underfitting
         self.weight_decay = 1e-4  # L2 weights regularization
         self.momentum = 0.9

-        # Test
-        self.test_episodes = 2  # Number of games played to evaluate the network
-
         # Exponential learning rate schedule
-        self.lr_init = 0.0005  # Initial learning rate
-        self.lr_decay_rate = 0.1
-        self.lr_decay_steps = 3500
+        self.lr_init = 0.008  # Initial learning rate
+        self.lr_decay_rate = 0.01
+        self.lr_decay_steps = 10000
+
+        ### Test
+        self.test_episodes = 2  # Number of games played to evaluate the network

     def visit_softmax_temperature_fn(self, trained_steps):
         """
@@ -58,7 +63,7 @@ class MuZeroConfig:
         Returns:
             Positive float.
         """
-        if trained_steps < 0.25 * self.training_steps:
+        if trained_steps < 0.5 * self.training_steps:
             return 1.0
         elif trained_steps < 0.75 * self.training_steps:
             return 0.5
@@ -67,7 +72,8 @@ class MuZeroConfig:

 class Game:
-    """Game wrapper.
+    """
+    Game wrapper.
     """

     def __init__(self, seed=None):
@@ -76,7 +82,8 @@ class Game:
         self.env.seed(seed)

     def step(self, action):
-        """Apply action to the game.
+        """
+        Apply action to the game.

         Args:
             action : action of the action_space to take.
@@ -88,7 +95,8 @@ class Game:
         return numpy.array(observation).flatten(), reward, done

     def reset(self):
-        """Reset the game for a new game.
+        """
+        Reset the game for a new game.

         Returns:
             Initial observation of the game.
@@ -96,12 +104,14 @@ class Game:
         return self.env.reset()

     def close(self):
-        """Properly close the game.
+        """
+        Properly close the game.
         """
         self.env.close()

     def render(self):
-        """Display the game observation.
+        """
+        Display the game observation.
         """
         self.env.render()
         input("Press enter to take a step ")
@@ -7,16 +7,18 @@ class MuZeroConfig:
     def __init__(self):
         self.seed = 0  # Seed for numpy, torch and the game

         ### Game
         self.observation_shape = 8  # Dimensions of the game observation
-        self.action_space = [i for i in range(4)]  # Fixed list of all possible actions
+        self.action_space = [i for i in range(4)]  # Fixed list of all possible actions (float between 0 and 1)

         ### Self-Play
         self.num_actors = 10  # Number of simultaneous threads self-playing to feed the replay buffer
         self.max_moves = 500  # Maximum number of moves if game is not finished before
-        self.num_simulations = 50  # Number of future moves self-simulated
+        self.num_simulations = 80  # Number of future moves self-simulated
         self.discount = 0.997  # Chronological discount of the reward
-        self.self_play_delay = None  # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid overfitting (Recommended is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
+        self.self_play_delay = 0  # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting

         # Root prior exploration noise
         self.root_dirichlet_alpha = 0.25
@@ -26,30 +28,34 @@ class MuZeroConfig:
         self.pb_c_base = 19652
         self.pb_c_init = 1.25

         ### Network
-        self.encoding_size = 64
-        self.hidden_size = 32
+        self.encoding_size = 32
+        self.hidden_size = 64

-        # Training
+        ### Training
         self.results_path = "./pretrained"  # Path to store the model weights
-        self.training_steps = 2000  # Total number of training steps (ie weights update according to a batch)
+        self.training_steps = 10000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
         self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
-        self.training_delay = 8  # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid overfitting (Recommended is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
+        self.training_delay = 0  # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid over/underfitting
         self.weight_decay = 1e-4  # L2 weights regularization
         self.momentum = 0.9

-        # Test
+        # Exponential learning rate schedule
+        self.lr_init = 0.01  # Initial learning rate
+        self.lr_decay_rate = 0.005
+        self.lr_decay_steps = 10000
+
+        ### Test
         self.test_episodes = 2  # Number of games played to evaluate the network
-
-        # Exponential learning rate schedule
-        self.lr_init = 0.0001  # Initial learning rate
-        self.lr_decay_rate = 0.1
-        self.lr_decay_steps = 3500

     def visit_softmax_temperature_fn(self, trained_steps):
         """
@@ -59,7 +65,7 @@ class MuZeroConfig:
         Returns:
             Positive float.
         """
-        if trained_steps < 0.25 * self.training_steps:
+        if trained_steps < 0.5 * self.training_steps:
             return 1.0
         elif trained_steps < 0.75 * self.training_steps:
             return 0.5
@@ -68,7 +74,8 @@ class MuZeroConfig:

 class Game:
-    """Game wrapper.
+    """
+    Game wrapper.
     """

     def __init__(self, seed=None):
@@ -77,7 +84,8 @@ class Game:
         self.env.seed(seed)

     def step(self, action):
-        """Apply action to the game.
+        """
+        Apply action to the game.

         Args:
             action : action of the action_space to take.
@@ -89,7 +97,8 @@ class Game:
         return numpy.array(observation).flatten(), reward, done

     def reset(self):
-        """Reset the game for a new game.
+        """
+        Reset the game for a new game.

         Returns:
             Initial observation of the game.
@@ -97,12 +106,14 @@ class Game:
         return self.env.reset()

     def close(self):
-        """Properly close the game.
+        """
+        Properly close the game.
         """
         self.env.close()

     def render(self):
-        """Display the game observation.
+        """
+        Display the game observation.
         """
         self.env.render()
         input("Press enter to take a step ")
@@ -40,6 +40,10 @@ class MuZeroNetwork(torch.nn.Module):
         self.dynamics_encoded_state_network = FullyConnectedNetwork(
             encoding_size + self.action_space_size, [hidden_size], encoding_size
         )
+        # Gradient scaling (See paper appendix Training)
+        self.dynamics_encoded_state_network.register_backward_hook(
+            lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
+        )
         self.dynamics_reward_network = FullyConnectedNetwork(
             encoding_size + self.action_space_size, [hidden_size], 1
         )
@@ -60,20 +64,25 @@ class MuZeroNetwork(torch.nn.Module):
         return self.representation_network(observation)

     def dynamics(self, encoded_state, action):
+        # Stack encoded_state with one hot action (See paper appendix Network Architecture)
         action_one_hot = (
             torch.zeros((action.shape[0], self.action_space_size))
             .to(action.device)
             .float()
         )
         action_one_hot.scatter_(1, action.long(), 1.0)
         x = torch.cat((encoded_state, action_one_hot), dim=1)

         next_encoded_state = self.dynamics_encoded_state_network(x)
         reward = self.dynamics_reward_network(x)
         return next_encoded_state, reward

     def initial_inference(self, observation):
         encoded_state = self.representation(observation)
+        # Scale encoded state between [0, 1] (See paper Training appendix)
+        encoded_state = (encoded_state - torch.min(encoded_state)) / (
+            torch.max(encoded_state) - torch.min(encoded_state)
+        )
         policy_logit, value = self.prediction(encoded_state)
         return (
             value,
@@ -84,6 +93,10 @@ class MuZeroNetwork(torch.nn.Module):
     def recurrent_inference(self, encoded_state, action):
         next_encoded_state, reward = self.dynamics(encoded_state, action)
+        # Scale encoded state between [0, 1] (See paper Training appendix)
+        next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
+            torch.max(next_encoded_state) - torch.min(next_encoded_state)
+        )
         policy_logit, value = self.prediction(next_encoded_state)
         return value, reward, policy_logit, next_encoded_state
...
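Note: the hook registered on dynamics_encoded_state_network halves the gradient flowing back through the dynamics function, as the paper's Training appendix prescribes, while the min/max rescaling in initial_inference and recurrent_inference maps the hidden state into [0, 1]. An equivalent way to halve the gradient, used in the paper's pseudocode, is to scale it on the hidden-state tensor itself rather than through a module-level backward hook. A minimal sketch of that alternative (not the code used in this commit):

import torch

def scale_gradient(tensor, scale):
    # Forward value is unchanged (tensor * scale + tensor * (1 - scale) == tensor),
    # but only the first term carries gradient, so the backward pass is scaled by `scale`.
    return tensor * scale + tensor.detach() * (1 - scale)

# Hypothetical placement inside recurrent_inference, after the dynamics call:
# next_encoded_state, reward = self.dynamics(encoded_state, action)
# next_encoded_state = scale_gradient(next_encoded_state, 0.5)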
...@@ -169,5 +169,5 @@ if __name__ == "__main__": ...@@ -169,5 +169,5 @@ if __name__ == "__main__":
muzero = MuZero("cartpole") muzero = MuZero("cartpole")
muzero.train() muzero.train()
muzero.load_model() # muzero.load_model()
muzero.test() muzero.test()
@@ -3,7 +3,7 @@
   "nbformat_minor": 0,
   "metadata": {
     "colab": {
-      "name": "Untitled3.ipynb",
+      "name": "muzero.ipynb",
       "provenance": []
     },
     "kernelspec": {
@@ -50,4 +50,4 @@
     ]
   }
  ]
 }
\ No newline at end of file
@@ -36,7 +36,7 @@ class ReplayBuffer:
             actions = game_history.history[
                 game_pos : game_pos + self.config.num_unroll_steps
             ]
-            # Repeat previous action to make "actions" of length "num_unroll_steps"
+            # Repeat previous action to make 'actions' of length 'num_unroll_steps'
             actions.extend(
                 [
                     actions[-1]
@@ -71,7 +71,7 @@ def sample_position(game_history):
     """
     Sample position from game either uniformly or according to some priority.
     """
-    # TODO: according to some priority
+    # TODO: sample according to some priority
     return numpy.random.choice(range(len(game_history.rewards)))
@@ -95,7 +95,7 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps):
             value += reward * game_history.discount ** i

         if current_index < len(game_history.root_values):
-            target_values.append(value)
+            target_values.append(0.25 * value)
             target_rewards.append(game_history.rewards[current_index])
             target_policies.append(game_history.child_visits[current_index])
         else:
...
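Note: the value scaled by 0.25 above is the n-step bootstrapped return that make_target accumulates; the factor appears to shrink the value target's magnitude relative to the reward and policy targets. A sketch of the usual return computation, with names taken from the hunk header (assumed, not verified against the full file):

def n_step_return(game_history, current_index, td_steps):
    # Bootstrap from the stored MCTS root value td_steps into the future, if it exists...
    bootstrap_index = current_index + td_steps
    if bootstrap_index < len(game_history.root_values):
        value = game_history.root_values[bootstrap_index] * game_history.discount ** td_steps
    else:
        value = 0
    # ...then add the discounted rewards collected along the way.
    for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
        value += reward * game_history.discount ** i
    return value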
@@ -68,7 +68,7 @@ class SelfPlay:
         game_history.observation_history.append(observation)
         done = False
         while not done and len(game_history.history) < self.config.max_moves:
-            root = MCTS(self.config).run(self.model, observation, True)
+            root = MCTS(self.config).run(self.model, observation, True if temperature else False)
             action = select_action(root, temperature)
@@ -76,7 +76,6 @@ class SelfPlay:
             if render:
                 self.game.render()
-                print("Press enter to step")
             game_history.observation_history.append(observation)
             game_history.rewards.append(reward)
@@ -209,8 +208,6 @@ class MCTS:
         to the root.
         """
         for node in search_path:
-            # Always the same player; the other players' minds should be modeled in the network,
-            # because the environment does not always act in the best way to make you lose
             node.value_sum += value  # if node.to_play == to_play else -value
             node.visit_count += 1
             min_max_stats.update(node.value())
...
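Note: the changed MCTS.run call passes True only when the temperature is nonzero, presumably so that root exploration noise is skipped during greedy evaluation play. select_action itself is not shown in this diff; a sketch of the visit-count sampling it is assumed to perform:

import numpy

def select_action_from_visits(root, temperature):
    # Greedy at temperature 0, otherwise sample with probabilities
    # proportional to visit_count ** (1 / temperature).
    actions = list(root.children.keys())
    visit_counts = numpy.array([child.visit_count for child in root.children.values()])
    if temperature == 0:
        return actions[numpy.argmax(visit_counts)]
    distribution = visit_counts ** (1 / temperature)
    return numpy.random.choice(actions, p=distribution / distribution.sum())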
@@ -116,10 +116,12 @@ class Trainer:
             reward_loss += current_reward_loss
             policy_loss += current_policy_loss

-        # Scale gradient by number of unroll steps (See paper Training appendix)
         loss = (
             value_loss + reward_loss + policy_loss
-        ).mean() / self.config.num_unroll_steps
+        ).mean()
+
+        # Scale gradient by number of unroll steps (See paper Training appendix)
+        loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)

         # Optimize
         self.optimizer.zero_grad()
@@ -139,7 +141,7 @@ def loss_function(
     value, reward, policy_logits, target_value, target_reward, target_policy
 ):
     # TODO: paper promotes cross entropy instead of MSE
-    value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
-    reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
-    policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
+    value_loss = torch.nn.MSELoss()(value, target_value)
+    reward_loss = torch.nn.MSELoss()(reward, target_reward)
+    policy_loss = torch.mean(torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1))
     return value_loss, reward_loss, policy_loss
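Note: the rewritten policy_loss is the same cross-entropy between the MCTS visit distribution and the predicted logits as before, now batch-averaged inside loss_function. The TODO refers to the paper's further trick of also training value and reward with cross-entropy by spreading each scalar over a discrete support. A hedged sketch of that transform (this helper is not part of the commit, and the h(x) contraction from the paper appendix is omitted):

import torch

def scalar_to_support(x, support_size):
    # x: 1-D batch of scalars. Each scalar is clamped to [-support_size, support_size]
    # and split over the two nearest bins of a (2 * support_size + 1)-wide support,
    # giving a target distribution that can be trained with cross-entropy.
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob_upper = x - floor
    probs = torch.zeros(x.shape[0], 2 * support_size + 1)
    lower_index = (floor + support_size).long()
    upper_index = torch.clamp(lower_index + 1, max=2 * support_size)
    probs.scatter_(1, upper_index.unsqueeze(-1), prob_upper.unsqueeze(-1))
    probs.scatter_(1, lower_index.unsqueeze(-1), (1 - prob_upper).unsqueeze(-1))
    return probs

# Example: scalar_to_support(torch.tensor([2.3, -1.0]), support_size=10)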