Commit 4b81586b by Werner Duvaud

Add multiplayer mode

parent c9059c39
...@@ -6,31 +6,52 @@ ...@@ -6,31 +6,52 @@
# MuZero General # MuZero General
A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py). A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable for every games or reinforcement learning environments (like [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py). It is designed to be easily adaptable for every games or reinforcement learning environments (like [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the [documentation](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games without knowing the rules. It only knows actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122). MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games without knowing the rules. It only knows actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122). See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works)
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for running the different components simultaneously. GPU training is supported. See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works) ## Features
All performances are tracked and displayed in real time in tensorboard. * [x] Fully connected network in [PyTorch](https://github.com/pytorch/pytorch)
* [x] Multi-Threaded with [Ray](https://github.com/ray-project/ray)
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/cartpole_training_summary.png) * [x] CPU/GPU support
* [x] TensorBoard real-time monitoring
* [x] Single and multiplayer mode
* [x] Commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation)
* [x] Easily adaptable for new games
* [x] [Examples](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py) of board and Gym games (See [list below](https://github.com/werner-duvaud/muzero-general#games-already-implemented-with-pretrained-network-available))
* [x] [Pretrained weights](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained) available
* [ ] Play against MuZero mode with policy and value tracking
* [ ] Residual Network
* [ ] Atari games
## Games already implemented with pretrained network available ## Games already implemented with pretrained network available
* Lunar Lander
* Cartpole * Cartpole
* Lunar Lander
* Connect4
## Demo
All performances are tracked and displayed in real time in tensorboard :
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/cartpole_training_summary.png)
Testing Lunar Lander :
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/lunarlander_training_preview.png) ![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/lunarlander_training_preview.png)
## Getting started ## Getting started
### Installation ### Installation
```bash ```bash
cd muzero-general cd muzero-general
pip install -r requirements.txt pip install -r requirements.txt
``` ```
### Training ### Training
Edit the end of muzero.py: Edit the end of muzero.py:
```python ```python
muzero = Muzero("cartpole") muzero = Muzero("cartpole")
...@@ -40,12 +61,13 @@ Then run: ...@@ -40,12 +61,13 @@ Then run:
```bash ```bash
python muzero.py python muzero.py
``` ```
To visualize the training results, run in a new bash: To visualize the training results, run in a new terminal:
```bash ```bash
tensorboard --logdir ./ tensorboard --logdir ./
``` ```
### Testing ### Testing
Edit the end of muzero.py: Edit the end of muzero.py:
```python ```python
muzero = Muzero("cartpole") muzero = Muzero("cartpole")
...@@ -57,14 +79,8 @@ Then run: ...@@ -57,14 +79,8 @@ Then run:
python muzero.py python muzero.py
``` ```
## Coming soon
* [ ] Atari mode with residual network
* [ ] Live test policy & value tracking
* [ ] [Open spiel](https://github.com/deepmind/open_spiel) integration
* [ ] Checkers game
* [ ] TensorFlow mode
## Authors ## Authors
* Werner Duvaud * Werner Duvaud
* Aurèle Hainaut * Aurèle Hainaut
* Paul Lenoir * Paul Lenoir
...@@ -11,6 +11,7 @@ class MuZeroConfig: ...@@ -11,6 +11,7 @@ class MuZeroConfig:
### Game ### Game
self.observation_shape = 4 # Dimensions of the game observation self.observation_shape = 4 # Dimensions of the game observation
self.action_space = [i for i in range(2)] # Fixed list of all possible actions self.action_space = [i for i in range(2)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
### Self-Play ### Self-Play
...@@ -95,7 +96,16 @@ class Game: ...@@ -95,7 +96,16 @@ class Game:
The new observation, the reward and a boolean if the game has ended. The new observation, the reward and a boolean if the game has ended.
""" """
observation, reward, done, _ = self.env.step(action) observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward, done return observation, reward, done
def to_play(self):
"""
Return the current player.
Returns:
The current player, it should be an element of the players list in the config.
"""
return 0
def reset(self): def reset(self):
""" """
......
import gym
import numpy
import torch
class MuZeroConfig:
def __init__(self):
self.seed = 0 # Seed for numpy, torch and the game
### Game
self.observation_shape = 6 * 7 # Dimensions of the game observation
self.action_space = [i for i in range(7)] # Fixed list of all possible actions
self.players = [i for i in range(2)] # List of players
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 50 # Maximum number of moves if game is not finished before
self.num_simulations = 30 # Number of futur moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
self.root_exploration_fraction = 0.25
# UCB formula
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 32
self.hidden_size = 64
### Training
self.results_path = "./pretrained" # Path to store the model weights
self.training_steps = 10000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.checkpoint_interval = 10 # Number of training steps before using the model for sef-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of steps in the futur to take into account for calculating the target value
self.training_delay = 0 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
self.training_device = "cuda" if torch.cuda.is_available() else "cpu" # Train on GPU if available
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.05 # Initial learning rate
self.lr_decay_rate = 0.01
self.lr_decay_steps = 10000
### Test
self.test_episodes = 2 # Number of game played to evaluate the network
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.
Returns:
Positive float.
"""
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
return 0.25
class Game:
"""
Game wrapper.
"""
def __init__(self, seed=None):
self.env = Connect4()
def step(self, action):
"""
Apply action to the game.
Args:
action : action of the action_space to take.
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
if action not in self.env.legal_actions():
observation, reward, done = self.env.step(self.env.legal_actions()[0])
reward = -1
done = True
else:
observation, reward, done = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
def to_play(self):
"""
Return the current player.
Returns:
The current player, it should be an element of the players list in the config.
"""
return self.env.to_play()
def reset(self):
"""
Reset the game for a new game.
Returns:
Initial observation of the game.
"""
return numpy.array(self.env.reset()).flatten()
def close(self):
"""
Properly close the game.
"""
pass
def render(self):
"""
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")
class Connect4:
def __init__(self):
self.board = numpy.zeros((6, 7)).astype(int)
self.player = 1
def to_play(self):
return 0 if self.player == 1 else 1
def reset(self):
self.board = numpy.zeros((6, 7)).astype(int)
self.player = 1
return self.get_observation()
def step(self, action):
for i in range(6):
if self.board[i][action] == 0:
self.board[i][action] = self.player
break
done = self.is_finished()
self.player *= -1
return self.get_observation(), 1 if done else 0, done
def get_observation(self):
if self.player == 1:
return self.board
else:
return -self.board
def legal_actions(self):
legal = []
for i in range(7):
for j in range(6):
if self.board[j][i] == 0:
legal.append(i)
break
return legal
def is_finished(self):
# Horizontal check
for i in range(4):
for j in range(6):
if (
self.board[j][i] == self.player
and self.board[j][i + 1] == self.player
and self.board[j][i + 2] == self.player
and self.board[j][i + 3] == self.player
):
return True
# Vertical check
for i in range(7):
for j in range(3):
if (
self.board[j][i] == self.player
and self.board[j + 1][i] == self.player
and self.board[j + 2][i] == self.player
and self.board[j + 3][i] == self.player
):
return True
# x diag check
for i in range(4):
for j in range(3):
if (
self.board[j][i] == self.player
and self.board[j + 1][i + 1] == self.player
and self.board[j + 2][i + 2] == self.player
and self.board[j + 3][i + 3] == self.player
):
return True
# -x diag check
for i in range(4):
for j in range(3, 6):
if (
self.board[j][i] == self.player
and self.board[j - 1][i + 1] == self.player
and self.board[j - 2][i + 2] == self.player
and self.board[j - 3][i + 3] == self.player
):
return True
return False
def render(self):
print(self.player * self.get_observation()[::-1])
...@@ -11,12 +11,13 @@ class MuZeroConfig: ...@@ -11,12 +11,13 @@ class MuZeroConfig:
### Game ### Game
self.observation_shape = 8 # Dimensions of the game observation self.observation_shape = 8 # Dimensions of the game observation
self.action_space = [i for i in range(4)] # Fixed list of all possible actions self.action_space = [i for i in range(4)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
### Self-Play ### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if game is not finished before self.max_moves = 200 # Maximum number of moves if game is not finished before
self.num_simulations = 80 # Number of futur moves self-simulated self.num_simulations = 50 # Number of futur moves self-simulated
self.discount = 0.997 # Chronological discount of the reward self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
...@@ -95,7 +96,16 @@ class Game: ...@@ -95,7 +96,16 @@ class Game:
The new observation, the reward and a boolean if the game has ended. The new observation, the reward and a boolean if the game has ended.
""" """
observation, reward, done, _ = self.env.step(action) observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward, done return numpy.array(observation).flatten(), reward/100, done
def to_play(self):
"""
Return the current player.
Returns:
The current player, it should be an element of the players list in the config.
"""
return 0
def reset(self): def reset(self):
""" """
......
...@@ -61,7 +61,13 @@ class MuZeroNetwork(torch.nn.Module): ...@@ -61,7 +61,13 @@ class MuZeroNetwork(torch.nn.Module):
return policy_logit, value return policy_logit, value
def representation(self, observation): def representation(self, observation):
return self.representation_network(observation) encoded_state = self.representation_network(observation)
# TODO: Try to remove tanh activation
# Scale encoded state between [0, 1] (See appendix paper Training)
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
)
return encoded_state
def dynamics(self, encoded_state, action): def dynamics(self, encoded_state, action):
# Stack encoded_state with one hot action (See paper appendix Network Architecture) # Stack encoded_state with one hot action (See paper appendix Network Architecture)
...@@ -74,15 +80,16 @@ class MuZeroNetwork(torch.nn.Module): ...@@ -74,15 +80,16 @@ class MuZeroNetwork(torch.nn.Module):
x = torch.cat((encoded_state, action_one_hot), dim=1) x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x) next_encoded_state = self.dynamics_encoded_state_network(x)
# TODO: Try to remove tanh activation
# Scale encoded state between [0, 1] (See paper appendix Training)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
)
reward = self.dynamics_reward_network(x) reward = self.dynamics_reward_network(x)
return next_encoded_state, reward return next_encoded_state, reward
def initial_inference(self, observation): def initial_inference(self, observation):
encoded_state = self.representation(observation) encoded_state = self.representation(observation)
# Scale encoded state between [0, 1] (See paper Training appendix)
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
)
policy_logit, value = self.prediction(encoded_state) policy_logit, value = self.prediction(encoded_state)
return ( return (
value, value,
...@@ -93,10 +100,6 @@ class MuZeroNetwork(torch.nn.Module): ...@@ -93,10 +100,6 @@ class MuZeroNetwork(torch.nn.Module):
def recurrent_inference(self, encoded_state, action): def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action) next_encoded_state, reward = self.dynamics(encoded_state, action)
# Scale encoded state between [0, 1] (See paper Training appendix)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
)
policy_logit, value = self.prediction(next_encoded_state) policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state return value, reward, policy_logit, next_encoded_state
......
...@@ -45,10 +45,7 @@ class MuZero: ...@@ -45,10 +45,7 @@ class MuZero:
) )
raise err raise err
try: os.makedirs(os.path.join(self.config.results_path), exist_ok=True)
os.mkdir(os.path.join(self.config.results_path))
except FileExistsError:
pass
# Fix random generator seed for reproductibility # Fix random generator seed for reproductibility
numpy.random.seed(self.config.seed) numpy.random.seed(self.config.seed)
...@@ -69,9 +66,9 @@ class MuZero: ...@@ -69,9 +66,9 @@ class MuZero:
) )
# Initialize workers # Initialize workers
training_worker = trainer.Trainer.remote( training_worker = trainer.Trainer.options(
copy.deepcopy(self.muzero_weights), self.config num_gpus=1 if "cuda" in self.config.training_device else 0
) ).remote(copy.deepcopy(self.muzero_weights), self.config)
shared_storage_worker = shared_storage.SharedStorage.remote( shared_storage_worker = shared_storage.SharedStorage.remote(
copy.deepcopy(self.muzero_weights), self.game_name, self.config, copy.deepcopy(self.muzero_weights), self.game_name, self.config,
) )
...@@ -106,6 +103,7 @@ class MuZero: ...@@ -106,6 +103,7 @@ class MuZero:
) )
counter = 0 counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote()) infos = ray.get(shared_storage_worker.get_infos.remote())
try:
while infos["training_step"] < self.config.training_steps: while infos["training_step"] < self.config.training_steps:
# Get and save real time performance # Get and save real time performance
infos = ray.get(shared_storage_worker.get_infos.remote()) infos = ray.get(shared_storage_worker.get_infos.remote())
...@@ -136,6 +134,8 @@ class MuZero: ...@@ -136,6 +134,8 @@ class MuZero:
) )
counter += 1 counter += 1
time.sleep(3) time.sleep(3)
except KeyboardInterrupt:
pass
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote()) self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown() ray.shutdown()
...@@ -149,7 +149,6 @@ class MuZero: ...@@ -149,7 +149,6 @@ class MuZero:
copy.deepcopy(self.muzero_weights), self.Game(), self.config copy.deepcopy(self.muzero_weights), self.Game(), self.config
) )
test_rewards = [] test_rewards = []
with torch.no_grad():
for _ in range(self.config.test_episodes): for _ in range(self.config.test_episodes):
history = ray.get(self_play_workers.play_game.remote(0, render)) history = ray.get(self_play_workers.play_game.remote(0, render))
test_rewards.append(sum(history.rewards)) test_rewards.append(sum(history.rewards))
...@@ -170,5 +169,5 @@ if __name__ == "__main__": ...@@ -170,5 +169,5 @@ if __name__ == "__main__":
muzero = MuZero("cartpole") muzero = MuZero("cartpole")
muzero.train() muzero.train()
#muzero.load_model() muzero.load_model()
muzero.test() muzero.test()
...@@ -33,24 +33,17 @@ class ReplayBuffer: ...@@ -33,24 +33,17 @@ class ReplayBuffer:
for _ in range(self.config.batch_size): for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer) game_history = sample_game(self.buffer)
game_pos = sample_position(game_history) game_pos = sample_position(game_history)
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps value, reward, policy, actions = make_target(
]
# Repeat precedent action to make 'actions' of length 'num_unroll_steps'
actions.extend(
[
actions[-1]
for _ in range(self.config.num_unroll_steps - len(actions) + 1)
]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = make_target(
game_history, game_history,
game_pos, game_pos,
self.config.num_unroll_steps, self.config.num_unroll_steps,
self.config.td_steps, self.config.td_steps,
self.config.discount,
) )
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value_batch.append(value) value_batch.append(value)
reward_batch.append(reward) reward_batch.append(reward)
policy_batch.append(policy) policy_batch.append(policy)
...@@ -63,7 +56,7 @@ def sample_game(buffer): ...@@ -63,7 +56,7 @@ def sample_game(buffer):
Sample game from buffer either uniformly or according to some priority. Sample game from buffer either uniformly or according to some priority.
""" """
# TODO: sample with probability link to the highest difference between real and # TODO: sample with probability link to the highest difference between real and
# predicted value (see paper appendix Training) # predicted value (See paper appendix Training)
return numpy.random.choice(buffer) return numpy.random.choice(buffer)
...@@ -75,29 +68,29 @@ def sample_position(game_history): ...@@ -75,29 +68,29 @@ def sample_position(game_history):
return numpy.random.choice(range(len(game_history.rewards))) return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps): def make_target(game_history, state_index, num_unroll_steps, td_steps, discount):
""" """
The value target is the discounted root value of the search tree td_steps into the The value target is the discounted root value of the search tree td_steps into the
future, plus the discounted sum of all rewards until then. future, plus the discounted sum of all rewards until then.
""" """
target_values, target_rewards, target_policies = [], [], [] target_values, target_rewards, target_policies, actions = [], [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1): for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values): if bootstrap_index < len(game_history.root_values):
value = ( value = game_history.root_values[bootstrap_index] * discount ** td_steps
game_history.root_values[bootstrap_index]
* game_history.discount ** td_steps
)
else: else:
value = 0 value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]): for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * game_history.discount ** i value += reward * discount ** i
if current_index < len(game_history.root_values): if current_index < len(game_history.root_values):
target_values.append(0.25 * value) # TODO: Scale and transform the reward and the value, don't forget the invert transform in inference time (See paper appendix Network Architecture)
# Value target could be scaled by 0.25 (See paper appendix Reanalyze)
target_values.append(value)
target_rewards.append(game_history.rewards[current_index]) target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index]) target_policies.append(game_history.child_visits[current_index])
actions.append(game_history.action_history[current_index])
else: else:
# States past the end of games are treated as absorbing states # States past the end of games are treated as absorbing states
target_values.append(0) target_values.append(0)
...@@ -109,5 +102,6 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps): ...@@ -109,5 +102,6 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps):
for _ in range(len(game_history.child_visits[0])) for _ in range(len(game_history.child_visits[0]))
] ]
) )
actions.append(game_history.action_history[-1])
return target_values, target_rewards, target_policies return target_values, target_rewards, target_policies, actions
import copy
import math import math
import time import time
import copy
import numpy import numpy
import ray import ray
import torch import torch
...@@ -26,11 +27,10 @@ class SelfPlay: ...@@ -26,11 +27,10 @@ class SelfPlay:
self.config.hidden_size, self.config.hidden_size,
) )
self.model.set_weights(initial_weights) self.model.set_weights(initial_weights)
self.model.to(torch.device('cpu')) self.model.to(torch.device("cpu"))
self.model.eval() self.model.eval()
def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False): def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
with torch.no_grad():
while True: while True:
self.model.set_weights( self.model.set_weights(
copy.deepcopy(ray.get(shared_storage.get_weights.remote())) copy.deepcopy(ray.get(shared_storage.get_weights.remote()))
...@@ -63,12 +63,18 @@ class SelfPlay: ...@@ -63,12 +63,18 @@ class SelfPlay:
""" """
Play one game with actions based on the Monte Carlo tree search at each moves. Play one game with actions based on the Monte Carlo tree search at each moves.
""" """
game_history = GameHistory(self.config.discount) game_history = GameHistory()
observation = self.game.reset() observation = self.game.reset()
game_history.observation_history.append(observation) game_history.observation_history.append(observation)
done = False done = False
while not done and len(game_history.history) < self.config.max_moves: with torch.no_grad():
root = MCTS(self.config).run(self.model, observation, True if temperature else False) while not done and len(game_history.action_history) < self.config.max_moves:
root = MCTS(self.config).run(
self.model,
observation,
self.game.to_play(),
True if temperature else False,
)
action = select_action(root, temperature) action = select_action(root, temperature)
...@@ -79,7 +85,7 @@ class SelfPlay: ...@@ -79,7 +85,7 @@ class SelfPlay:
game_history.observation_history.append(observation) game_history.observation_history.append(observation)
game_history.rewards.append(reward) game_history.rewards.append(reward)
game_history.history.append(action) game_history.action_history.append(action)
game_history.store_search_statistics(root, self.config.action_space) game_history.store_search_statistics(root, self.config.action_space)
self.game.close() self.game.close()
...@@ -98,7 +104,7 @@ def select_action(node, temperature): ...@@ -98,7 +104,7 @@ def select_action(node, temperature):
if temperature == 0: if temperature == 0:
action_pos = numpy.argmax(visit_counts[0]) action_pos = numpy.argmax(visit_counts[0])
else: else:
# See paper Data Generation appendix # See paper appendix Data Generation
visit_count_distribution = visit_counts[0] ** (1 / temperature) visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum( visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution visit_count_distribution
...@@ -125,7 +131,7 @@ class MCTS: ...@@ -125,7 +131,7 @@ class MCTS:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
def run(self, model, observation, add_exploration_noise): def run(self, model, observation, to_play, add_exploration_noise):
""" """
At the root of the search tree we use the representation function to obtain a At the root of the search tree we use the representation function to obtain a
hidden state given the current observation. hidden state given the current observation.
...@@ -143,7 +149,11 @@ class MCTS: ...@@ -143,7 +149,11 @@ class MCTS:
observation observation
) )
root.expand( root.expand(
self.config.action_space, expected_reward, policy_logits, hidden_state self.config.action_space,
to_play,
expected_reward,
policy_logits,
hidden_state,
) )
if add_exploration_noise: if add_exploration_noise:
root.add_exploration_noise( root.add_exploration_noise(
...@@ -154,6 +164,7 @@ class MCTS: ...@@ -154,6 +164,7 @@ class MCTS:
min_max_stats = MinMaxStats() min_max_stats = MinMaxStats()
for _ in range(self.config.num_simulations): for _ in range(self.config.num_simulations):
virtual_to_play = to_play
node = root node = root
search_path = [node] search_path = [node]
...@@ -162,6 +173,12 @@ class MCTS: ...@@ -162,6 +173,12 @@ class MCTS:
last_action = action last_action = action
search_path.append(node) search_path.append(node)
# Players play turn by turn
if virtual_to_play + 1 < len(self.config.players):
virtual_to_play = self.config.players[virtual_to_play + 1]
else:
virtual_to_play = self.config.players[0]
# Inside the search tree we use the dynamics function to obtain the next hidden # Inside the search tree we use the dynamics function to obtain the next hidden
# state given an action and the previous hidden state # state given an action and the previous hidden state
parent = search_path[-2] parent = search_path[-2]
...@@ -169,9 +186,15 @@ class MCTS: ...@@ -169,9 +186,15 @@ class MCTS:
parent.hidden_state, parent.hidden_state,
torch.tensor([[last_action]]).to(parent.hidden_state.device), torch.tensor([[last_action]]).to(parent.hidden_state.device),
) )
node.expand(self.config.action_space, reward, policy_logits, hidden_state) node.expand(
self.config.action_space,
virtual_to_play,
reward,
policy_logits,
hidden_state,
)
self.backpropagate(search_path, value.item(), min_max_stats) self.backpropagate(search_path, value.item(), to_play, min_max_stats)
return root return root
...@@ -202,13 +225,13 @@ class MCTS: ...@@ -202,13 +225,13 @@ class MCTS:
return prior_score + value_score return prior_score + value_score
def backpropagate(self, search_path, value, min_max_stats): def backpropagate(self, search_path, value, to_play, min_max_stats):
""" """
At the end of a simulation, we propagate the evaluation all the way up the tree At the end of a simulation, we propagate the evaluation all the way up the tree
to the root. to the root.
""" """
for node in search_path: for node in search_path:
node.value_sum += value # if node.to_play == to_play else -value node.value_sum += value if node.to_play == to_play else -value
node.visit_count += 1 node.visit_count += 1
min_max_stats.update(node.value()) min_max_stats.update(node.value())
...@@ -233,11 +256,12 @@ class Node: ...@@ -233,11 +256,12 @@ class Node:
return 0 return 0
return self.value_sum / self.visit_count return self.value_sum / self.visit_count
def expand(self, actions, reward, policy_logits, hidden_state): def expand(self, actions, to_play, reward, policy_logits, hidden_state):
""" """
We expand a node using the value, reward and policy prediction obtained from the We expand a node using the value, reward and policy prediction obtained from the
neural network. neural network.
""" """
self.to_play = to_play
self.reward = reward self.reward = reward
self.hidden_state = hidden_state self.hidden_state = hidden_state
policy = {a: math.exp(policy_logits[0][a]) for a in actions} policy = {a: math.exp(policy_logits[0][a]) for a in actions}
...@@ -262,13 +286,12 @@ class GameHistory: ...@@ -262,13 +286,12 @@ class GameHistory:
Store only usefull information of a self-play game. Store only usefull information of a self-play game.
""" """
def __init__(self, discount): def __init__(self):
self.observation_history = [] self.observation_history = []
self.history = [] self.action_history = []
self.rewards = [] self.rewards = []
self.child_visits = [] self.child_visits = []
self.root_values = [] self.root_values = []
self.discount = discount
def store_search_statistics(self, root, action_space): def store_search_statistics(self, root, action_space):
sum_visits = sum(child.visit_count for child in root.children.values()) sum_visits = sum(child.visit_count for child in root.children.values())
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import models import models
@ray.remote(num_gpus=1 if torch.cuda.is_available() else 0) @ray.remote
class Trainer: class Trainer:
""" """
Class which run in a dedicated thread to train a neural network and save it Class which run in a dedicated thread to train a neural network and save it
...@@ -64,12 +64,7 @@ class Trainer: ...@@ -64,12 +64,7 @@ class Trainer:
""" """
Perform one training step. Perform one training step.
""" """
# Update learning rate self.update_lr()
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
( (
observation_batch, observation_batch,
...@@ -116,9 +111,7 @@ class Trainer: ...@@ -116,9 +111,7 @@ class Trainer:
reward_loss += current_reward_loss reward_loss += current_reward_loss
policy_loss += current_policy_loss policy_loss += current_policy_loss
loss = ( loss = (value_loss + reward_loss + policy_loss).mean()
value_loss + reward_loss + policy_loss
).mean()
# Scale gradient by number of unroll steps (See paper Training appendix) # Scale gradient by number of unroll steps (See paper Training appendix)
loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps) loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)
...@@ -136,6 +129,16 @@ class Trainer: ...@@ -136,6 +129,16 @@ class Trainer:
policy_loss.mean().item(), policy_loss.mean().item(),
) )
def update_lr(self):
"""
Update learning rate
"""
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def loss_function( def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy value, reward, policy_logits, target_value, target_reward, target_policy
...@@ -143,5 +146,7 @@ def loss_function( ...@@ -143,5 +146,7 @@ def loss_function(
# TODO: paper promotes cross entropy instead of MSE # TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss()(value, target_value) value_loss = torch.nn.MSELoss()(value, target_value)
reward_loss = torch.nn.MSELoss()(reward, target_reward) reward_loss = torch.nn.MSELoss()(reward, target_reward)
policy_loss = torch.mean(torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1)) policy_loss = torch.mean(
torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1)
)
return value_loss, reward_loss, policy_loss return value_loss, reward_loss, policy_loss
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment