Commit fd660a25 by Werner Duvaud

Improve performance

parent 3967fb57
@@ -25,12 +25,7 @@ MuZero is a model based reinforcement learning algorithm, successor of AlphaZero
 * [ ] Play against MuZero mode with policy and value tracking
 * [ ] Residual Network
 * [ ] Atari games
-* [ ] Windows support ([workaround by ihexx](https://github.com/ihexx/muzero-general))
-## Games already implemented with pretrained network available
-* Cartpole
-* Lunar Lander
-* Connect4
 ## Demo
@@ -42,6 +37,16 @@ Testing Lunar Lander :
 ![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/lunarlander_training_preview.png)
+## Code structure
+![code structure](https://github.com/werner-duvaud/muzero-general/blob/master/docs/how-it-works-werner-duvaud.png)
+## Games already implemented with pretrained network available
+* Cartpole
+* Lunar Lander
+* Connect4
 ## Getting started
 ### Installation
...
@@ -37,12 +37,12 @@ class MuZeroConfig:
         ### Training
         self.results_path = "./pretrained"  # Path to store the model weights
-        self.training_steps = 10000  # Total number of training steps (i.e. weight updates according to a batch)
+        self.training_steps = 5000  # Total number of training steps (i.e. weight updates according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
-        self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
+        self.td_steps = 30  # Number of steps in the future to take into account for calculating the target value
         self.training_delay = 0  # Number of seconds to wait after each training step to adjust the self-play / training ratio to avoid over/underfitting
         self.training_device = "cuda" if torch.cuda.is_available() else "cpu"  # Train on GPU if available
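Raising `td_steps` from 10 to 30 makes the value target bootstrap from a position much further in the future, so it leans more on observed rewards and less on the learned value estimate. A minimal sketch of how `td_steps` and `discount` combine into an n-step value target, assuming the formulation from the MuZero paper (the repository's replay-buffer code may differ in its details):

```python
# Illustrative n-step value target, following the MuZero paper's formulation.
# Function and variable names here are made up for the example, not the repo's.
def n_step_value_target(rewards, root_values, index, td_steps, discount):
    bootstrap_index = index + td_steps
    # Bootstrap from the search value td_steps into the future, if the game lasts that long.
    value = (
        root_values[bootstrap_index] * discount ** td_steps
        if bootstrap_index < len(root_values)
        else 0
    )
    # Add the discounted rewards observed between the position and the bootstrap point.
    for i, reward in enumerate(rewards[index:bootstrap_index]):
        value += reward * discount ** i
    return value
```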
@@ -51,7 +51,7 @@ class MuZeroConfig:
         # Exponential learning rate schedule
         self.lr_init = 0.008  # Initial learning rate
-        self.lr_decay_rate = 0.01
+        self.lr_decay_rate = 1
         self.lr_decay_steps = 10000
...
@@ -51,7 +51,7 @@ class MuZeroConfig:
         # Exponential learning rate schedule
         self.lr_init = 0.05  # Initial learning rate
-        self.lr_decay_rate = 0.01
+        self.lr_decay_rate = 1
         self.lr_decay_steps = 10000
...
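Both hunks above set `lr_decay_rate` to 1, which turns the exponential schedule into a constant learning rate. A minimal sketch, assuming the schedule has the form used in the MuZero pseudocode (the repository's optimizer update may differ slightly):

```python
# Illustrative exponential learning rate schedule.
def exponential_lr(training_step, lr_init, lr_decay_rate, lr_decay_steps):
    return lr_init * lr_decay_rate ** (training_step / lr_decay_steps)

# With the old setting the rate collapses quickly; with lr_decay_rate = 1 it stays at lr_init.
print(exponential_lr(10000, 0.008, 0.01, 10000))  # 0.00008
print(exponential_lr(10000, 0.008, 1, 10000))     # 0.008
```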
@@ -16,7 +16,7 @@ class MuZeroConfig:
         ### Self-Play
         self.num_actors = 10  # Number of simultaneous threads self-playing to feed the replay buffer
-        self.max_moves = 200  # Maximum number of moves if game is not finished before
+        self.max_moves = 1000  # Maximum number of moves if game is not finished before
         self.num_simulations = 50  # Number of future moves self-simulated
         self.discount = 0.997  # Chronological discount of the reward
         self.self_play_delay = 0  # Number of seconds to wait after each played game to adjust the self-play / training ratio to avoid over/underfitting
@@ -31,16 +31,16 @@ class MuZeroConfig:
         ### Network
-        self.encoding_size = 32
-        self.hidden_size = 64
+        self.encoding_size = 64
+        self.hidden_size = 128
         ### Training
         self.results_path = "./pretrained"  # Path to store the model weights
-        self.training_steps = 20000  # Total number of training steps (i.e. weight updates according to a batch)
+        self.training_steps = 3000  # Total number of training steps (i.e. weight updates according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
-        self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
+        self.num_unroll_steps = 10  # Number of game moves to keep for every batch element
-        self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
+        self.checkpoint_interval = 3  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
         self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
         self.training_delay = 0  # Number of seconds to wait after each training step to adjust the self-play / training ratio to avoid over/underfitting
@@ -50,9 +50,9 @@ class MuZeroConfig:
         self.momentum = 0.9
         # Exponential learning rate schedule
-        self.lr_init = 0.01  # Initial learning rate
-        self.lr_decay_rate = 0.001
-        self.lr_decay_steps = 10000
+        self.lr_init = 0.00005  # Initial learning rate
+        self.lr_decay_rate = 1
+        self.lr_decay_steps = 100000
         ### Test
@@ -67,12 +67,16 @@ class MuZeroConfig:
         Returns:
             Positive float.
         """
-        if trained_steps < 0.5 * self.training_steps:
-            return 1.0
-        elif trained_steps < 0.75 * self.training_steps:
-            return 0.5
-        else:
-            return 0.25
+        # if trained_steps < 0.2 * self.training_steps:
+        #     return float('inf')
+        # if trained_steps < 0.5 * self.training_steps:
+        #     return 0.8
+        # elif trained_steps < 0.75 * self.training_steps:
+        #     return 0.5
+        # else:
+        #     return 0.25
+        return 1


 class Game:
...
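The temperature function now returns a constant 1, with a staged schedule (including an early fully random phase via `float('inf')`) left commented out for reference, so self-play keeps sampling actions in proportion to visit counts for the whole run instead of turning greedier over time. A small illustration of what the temperature does to the sampling distribution, using the `visit_counts ** (1 / temperature)` formula shown further down in `select_action`; the names below are made up for the example:

```python
import numpy

def temperature_distribution(visit_counts, temperature):
    # Sharpen or flatten the MCTS visit counts before sampling an action.
    scaled = visit_counts ** (1 / temperature)
    return scaled / scaled.sum()

visit_counts = numpy.array([10.0, 30.0, 60.0])
print(temperature_distribution(visit_counts, 1))     # [0.1 0.3 0.6], proportional to visits
print(temperature_distribution(visit_counts, 0.25))  # heavily peaked on the most visited action
# temperature == 0 is handled as a pure argmax and float('inf') as a uniform
# random choice in select_action, so neither goes through this formula.
```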
@@ -45,7 +45,7 @@ class MuZeroNetwork(torch.nn.Module):
             lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
         )
         self.dynamics_reward_network = FullyConnectedNetwork(
-            encoding_size + self.action_space_size, [hidden_size], 1
+            encoding_size + self.action_space_size, [hidden_size], 1, activation=None
         )
         self.prediction_policy_network = FullyConnectedNetwork(
...
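The reward head now passes `activation=None`, presumably so its output layer stays linear and the predicted reward is not squashed into an activation's range. The class below is only a guess at what such a `FullyConnectedNetwork` helper might look like; its signature mirrors the call sites above but is an assumption, not code taken from the repository:

```python
import torch

# Hypothetical fully connected helper with an optional output activation.
class FullyConnectedNetwork(torch.nn.Module):
    def __init__(self, input_size, layer_sizes, output_size, activation=torch.nn.Tanh()):
        super().__init__()
        layers = []
        sizes = [input_size] + list(layer_sizes)
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            layers += [torch.nn.Linear(in_features, out_features), torch.nn.ReLU()]
        layers.append(torch.nn.Linear(sizes[-1], output_size))
        if activation is not None:
            # With activation=None the head stays linear, so predicted values are unbounded.
            layers.append(activation)
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```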
@@ -134,9 +134,9 @@ class MuZero:
                 )
                 counter += 1
                 time.sleep(3)
-        except KeyboardInterrupt:
+        except KeyboardInterrupt as err:
             # Comment the line below to be able to stop the training but keep running
-            raise KeyboardInterrupt
+            raise err
             pass
         self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
         ray.shutdown()
...
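A brief aside on the pattern this hunk touches: re-raising the exception object that was caught (`raise err`, or a bare `raise`) keeps the original traceback, whereas `raise KeyboardInterrupt` constructs a new exception instance. The snippet below only illustrates that trade-off and is not the repository's training loop:

```python
import time

try:
    while True:
        # Placeholder for the monitoring loop that polls training progress.
        time.sleep(3)
except KeyboardInterrupt as err:
    # Re-raise to stop the program; replace this with `pass` to swallow Ctrl+C
    # and let background workers keep running instead.
    raise err
```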
No preview for this file type
No preview for this file type
No preview for this file type
@@ -73,7 +73,7 @@ class SelfPlay:
                     self.model,
                     observation,
                     self.game.to_play(),
-                    True if temperature else False,
+                    False if temperature == 0 else True,
                 )
                 action = select_action(root, temperature)
@@ -103,6 +103,8 @@ def select_action(node, temperature):
     ).T
     if temperature == 0:
         action_pos = numpy.argmax(visit_counts[0])
+    elif temperature == float("inf"):
+        action_pos = numpy.random.choice(len(visit_counts[1]))
     else:
         # See paper appendix Data Generation
         visit_count_distribution = visit_counts[0] ** (1 / temperature)
@@ -113,9 +115,6 @@ def select_action(node, temperature):
             len(visit_counts[1]), p=visit_count_distribution
         )
-    if temperature == float("inf"):
-        action_pos = numpy.random.choice(len(visit_counts[1]))
     return visit_counts[1][action_pos]
...
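Handling `float("inf")` before the general branch makes the fully random case explicit and skips the visit-count exponentiation and sampling that the old placement performed and then discarded. A simplified sketch of the resulting control flow; the helper below is illustrative and not the repository's exact function:

```python
import numpy

def select_action_from_counts(visit_counts, actions, temperature):
    # Stand-in for select_action: visit_counts and actions are the per-child
    # statistics that the real function extracts from the MCTS root node.
    if temperature == 0:
        # Greedy: always play the most visited action.
        action_pos = numpy.argmax(visit_counts)
    elif temperature == float("inf"):
        # Fully random: ignore the search statistics entirely.
        action_pos = numpy.random.choice(len(actions))
    else:
        # See the paper appendix, Data Generation: sample in proportion to
        # visit counts raised to 1 / temperature.
        distribution = visit_counts ** (1 / temperature)
        distribution = distribution / distribution.sum()
        action_pos = numpy.random.choice(len(actions), p=distribution)
    return actions[action_pos]

# Example: temperature 0 is greedy, 1 samples proportionally to visits, inf is uniform.
print(select_action_from_counts(numpy.array([10.0, 30.0, 60.0]), [0, 1, 2], 0))  # 2
```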