Commit c5f6bffc by Werner Duvaud

Add gradient scaling and update hyperparameters

parent ceb98db4
@@ -6,16 +6,18 @@ class MuZeroConfig:
def __init__(self):
self.seed = 0 # Seed for numpy, torch and the game
### Game
self.observation_shape = 4 # Dimensions of the game observation
self.action_space = [i for i in range(2)] # Fixed list of all possible actions
self.action_space = [i for i in range(2)] # Fixed list of all possible actions (float between 0 and 1)
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.num_simulations = 80 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = None # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid overfitting (recommended ratio is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
@@ -25,30 +27,33 @@ class MuZeroConfig:
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 64
self.hidden_size = 32
self.encoding_size = 32
self.hidden_size = 64
# Training
### Training
self.results_path = "./pretrained" # Path to store the model weights
self.training_steps = 2000 # Total number of training steps (i.e. weight updates, one per batch)
self.training_steps = 10000 # Total number of training steps (i.e. weight updates, one per batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of future steps to take into account for calculating the target value
self.training_delay = 1 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid overfitting (recommended ratio is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
self.training_delay = 0 # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid over/underfitting
self.weight_decay = 1e-4 # L2 weight regularization
self.momentum = 0.9
# Test
self.test_episodes = 2 # Number of games played to evaluate the network
# Exponential learning rate schedule
self.lr_init = 0.0005 # Initial learning rate
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
self.lr_init = 0.008 # Initial learning rate
self.lr_decay_rate = 0.01
self.lr_decay_steps = 10000
### Test
self.test_episodes = 2 # Number of games played to evaluate the network
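The three learning-rate settings above drive an exponential decay schedule. A minimal sketch of how they are presumably combined in the trainer (the helper name and call site are illustrative, not taken from this commit):

def exponential_lr(lr_init, lr_decay_rate, lr_decay_steps, training_step):
    # Exponential decay: starts at lr_init and reaches lr_init * lr_decay_rate
    # once training_step == lr_decay_steps.
    return lr_init * lr_decay_rate ** (training_step / lr_decay_steps)

# Example with the cartpole values above:
print(exponential_lr(0.008, 0.01, 10000, 0))      # 0.008
print(exponential_lr(0.008, 0.01, 10000, 10000))  # 8e-05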
def visit_softmax_temperature_fn(self, trained_steps):
"""
@@ -58,7 +63,7 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
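Taken together, this hunk delays the first temperature drop from 25% to 50% of training. A sketch of the full schedule for reference (the final branch is truncated out of this hunk, so the 0.25 return value is an assumption):

def visit_softmax_temperature_fn(self, trained_steps):
    # Higher temperature means more exploratory action selection from visit counts.
    if trained_steps < 0.5 * self.training_steps:
        return 1.0
    elif trained_steps < 0.75 * self.training_steps:
        return 0.5
    else:
        return 0.25  # assumed, not visible in the truncated hunk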
@@ -67,7 +72,8 @@ class MuZeroConfig:
class Game:
"""Game wrapper.
"""
Game wrapper.
"""
def __init__(self, seed=None):
@@ -76,7 +82,8 @@ class Game:
self.env.seed(seed)
def step(self, action):
"""Apply action to the game.
"""
Apply action to the game.
Args:
action : action of the action_space to take.
@@ -88,7 +95,8 @@ class Game:
return numpy.array(observation).flatten(), reward, done
def reset(self):
"""Reset the game for a new game.
"""
Reset the game for a new game.
Returns:
Initial observation of the game.
@@ -96,12 +104,14 @@ class Game:
return self.env.reset()
def close(self):
"""Properly close the game.
"""
Properly close the game.
"""
self.env.close()
def render(self):
"""Display the game observation.
"""
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")
@@ -7,16 +7,18 @@ class MuZeroConfig:
def __init__(self):
self.seed = 0 # Seed for numpy, torch and the game
### Game
self.observation_shape = 8 # Dimensions of the game observation
self.action_space = [i for i in range(4)] # Fixed list of all possible actions
self.action_space = [i for i in range(4)] # Fixed list of all possible actions (float between 0 and 1)
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.num_simulations = 80 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = None # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid overfitting (recommended ratio is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
@@ -26,30 +28,34 @@ class MuZeroConfig:
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 64
self.hidden_size = 32
self.encoding_size = 32
self.hidden_size = 64
# Training
### Training
self.results_path = "./pretrained" # Path to store the model weights
self.training_steps = 2000 # Total number of training steps (i.e. weight updates, one per batch)
self.training_steps = 10000 # Total number of training steps (i.e. weight updates, one per batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of future steps to take into account for calculating the target value
self.training_delay = 8 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid overfitting (recommended ratio is 13:1, see https://arxiv.org/abs/1902.04522 Appendix A)
self.training_delay = 0 # Number of seconds to wait after each training step to adjust the self play / training ratio to avoid over/underfitting
self.weight_decay = 1e-4 # L2 weight regularization
self.momentum = 0.9
# Test
# Exponential learning rate schedule
self.lr_init = 0.01 # Initial learning rate
self.lr_decay_rate = 0.005
self.lr_decay_steps = 10000
### Test
self.test_episodes = 2 # Number of games played to evaluate the network
# Exponential learning rate schedule
self.lr_init = 0.0001 # Initial learning rate
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, trained_steps):
"""
@@ -59,7 +65,7 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
@@ -68,7 +74,8 @@ class MuZeroConfig:
class Game:
"""Game wrapper.
"""
Game wrapper.
"""
def __init__(self, seed=None):
@@ -77,7 +84,8 @@ class Game:
self.env.seed(seed)
def step(self, action):
"""Apply action to the game.
"""
Apply action to the game.
Args:
action : action of the action_space to take.
@@ -89,7 +97,8 @@ class Game:
return numpy.array(observation).flatten(), reward, done
def reset(self):
"""Reset the game for a new game.
"""
Reset the game for a new game.
Returns:
Initial observation of the game.
@@ -97,12 +106,14 @@ class Game:
return self.env.reset()
def close(self):
"""Properly close the game.
"""
Properly close the game.
"""
self.env.close()
def render(self):
"""Display the game observation.
"""
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")
@@ -40,6 +40,10 @@ class MuZeroNetwork(torch.nn.Module):
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], encoding_size
)
# Gradient scaling (See paper appendix Training)
self.dynamics_encoded_state_network.register_backward_hook(
lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], 1
)
@@ -60,20 +64,25 @@ class MuZeroNetwork(torch.nn.Module):
return self.representation_network(observation)
def dynamics(self, encoded_state, action):
# Stack encoded_state with one hot action (See paper appendix Network Architecture)
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
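Here scatter_ writes a 1.0 into the column selected by each action index, giving one one-hot row per batch element that is concatenated to the encoded state. A small self-contained illustration of the same encoding using torch.nn.functional.one_hot (an equivalent this commit does not use):

import torch
import torch.nn.functional as F

action = torch.tensor([[1], [3], [0]])  # batch of action indices, shape (3, 1)
action_one_hot = F.one_hot(action.squeeze(1), num_classes=4).float()
# tensor([[0., 1., 0., 0.],
#         [0., 0., 0., 1.],
#         [1., 0., 0., 0.]])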
def initial_inference(self, observation):
encoded_state = self.representation(observation)
# Scale encoded state between [0, 1] (See paper Training appendix)
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
)
policy_logit, value = self.prediction(encoded_state)
return (
value,
@@ -84,6 +93,10 @@ class MuZeroNetwork(torch.nn.Module):
def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
# Scale encoded state between [0, 1] (See paper Training appendix)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
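Both inference paths now min-max scale the encoded state into [0, 1], the same range as the one-hot action input. Note that torch.min/torch.max here reduce over the entire batch tensor; a per-sample variant (shown only as an assumption about a possible refinement, not what this commit does) would be:

def normalize_encoded_state(encoded_state):
    # Scale each batch row to [0, 1] independently; the epsilon guards against
    # division by zero when a row is constant.
    min_val = encoded_state.min(dim=1, keepdim=True)[0]
    max_val = encoded_state.max(dim=1, keepdim=True)[0]
    return (encoded_state - min_val) / (max_val - min_val + 1e-8)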
@@ -169,5 +169,5 @@ if __name__ == "__main__":
muzero = MuZero("cartpole")
muzero.train()
muzero.load_model()
# muzero.load_model()
muzero.test()
@@ -3,7 +3,7 @@
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Untitled3.ipynb",
"name": "muzero.ipynb",
"provenance": []
},
"kernelspec": {
@@ -36,7 +36,7 @@ class ReplayBuffer:
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps
]
# Repeat the last action to make "actions" of length "num_unroll_steps"
# Repeat the last action to make 'actions' of length 'num_unroll_steps'
actions.extend(
[
actions[-1]
@@ -71,7 +71,7 @@ def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
# TODO: sample according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
@@ -95,7 +95,7 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps):
value += reward * game_history.discount ** i
if current_index < len(game_history.root_values):
target_values.append(value)
target_values.append(0.25 * value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
else:
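The loop above accumulates the discounted rewards of the n-step return used as the value target; the `0.25 *` factor is an additional scaling introduced by this commit. For context, a hedged sketch of the full target computation as it presumably appears earlier in make_target (the bootstrap part is not visible in this hunk and follows the standard MuZero target definition):

bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
    # Bootstrap from the MCTS root value td_steps into the future...
    value = game_history.root_values[bootstrap_index] * game_history.discount ** td_steps
else:
    value = 0
# ...then add the discounted rewards collected along the way.
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
    value += reward * game_history.discount ** i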
@@ -68,7 +68,7 @@ class SelfPlay:
game_history.observation_history.append(observation)
done = False
while not done and len(game_history.history) < self.config.max_moves:
root = MCTS(self.config).run(self.model, observation, True)
root = MCTS(self.config).run(self.model, observation, True if temperature else False)
action = select_action(root, temperature)
@@ -76,7 +76,6 @@ class SelfPlay:
if render:
self.game.render()
print("Press enter to step")
game_history.observation_history.append(observation)
game_history.rewards.append(reward)
@@ -209,8 +208,6 @@ class MCTS:
to the root.
"""
for node in search_path:
# Always the same player; the other players' strategies should be modeled by the network,
# because the environment does not always act in the way that is worst for you
node.value_sum += value # if node.to_play == to_play else -value
node.visit_count += 1
min_max_stats.update(node.value())
@@ -116,10 +116,12 @@ class Trainer:
reward_loss += current_reward_loss
policy_loss += current_policy_loss
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = (
value_loss + reward_loss + policy_loss
).mean() / self.config.num_unroll_steps
).mean()
# Scale gradient by number of unroll steps (See paper Training appendix)
loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)
# Optimize
self.optimizer.zero_grad()
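Registering a hook on the loss divides the gradient by num_unroll_steps while leaving the logged loss value unscaled; for the backward pass it is equivalent to optimizing loss / num_unroll_steps. A tiny self-contained check of that equivalence (illustrative values only):

import torch

w = torch.tensor([2.0], requires_grad=True)
loss = (w * 3.0).sum()                     # loss = 6.0, dloss/dw = 3.0
loss.register_hook(lambda grad: grad / 5)  # plays the role of 1 / num_unroll_steps
loss.backward()
print(loss.item())  # still 6.0: the reported loss is unchanged
print(w.grad)       # tensor([0.6000]) = 3.0 / 5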
@@ -139,7 +141,7 @@ def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
value_loss = torch.nn.MSELoss()(value, target_value)
reward_loss = torch.nn.MSELoss()(reward, target_reward)
policy_loss = torch.mean(torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1))
return value_loss, reward_loss, policy_loss
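The reworked loss_function applies mean reduction everywhere and computes the policy loss as a batch-averaged cross entropy between the target visit distribution and the predicted logits. A minimal usage sketch with assumed shapes (batch of 2, action space of 2):

import torch

value = torch.zeros(2, 1)
reward = torch.zeros(2, 1)
policy_logits = torch.zeros(2, 2)
target_value = torch.ones(2, 1)
target_reward = torch.ones(2, 1)
target_policy = torch.tensor([[0.7, 0.3], [0.5, 0.5]])

value_loss, reward_loss, policy_loss = loss_function(
    value, reward, policy_logits, target_value, target_reward, target_policy
)
# value_loss and reward_loss are scalar MSE values (both 1.0 here);
# policy_loss is the mean cross entropy against the visit-count targets.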