Commit 06564837 by Werner Duvaud

Refactored

parent d8388353
# MuZero General

A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).

It is designed to be easily adaptable to any game or reinforcement learning environment (like [gym](https://github.com/openai/gym)). You only need to edit the game file with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).

MuZero is a model-based reinforcement learning algorithm, successor of AlphaZero. It learns to master games without knowing their rules: it only knows the available actions and then learns to play and master the game. It is more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).

It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for running the different components simultaneously, with full GPU support.

There are four "actors": classes that each run simultaneously in a dedicated thread. The shared storage holds the latest neural network weights, the self-play workers use those weights to generate games and store them in the replay buffer, and the trainer uses those games to update the network and stores the new weights back in the shared storage. The circle is complete.

Those components are launched and managed from the MuZero class in muzero.py, and the structure of the neural network is defined in models.py.

All performance metrics are tracked and displayed in real time in TensorBoard.

![cartpole training summary](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained/pretrained/cartpole_training_summary.png)
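The sketch below shows roughly how `MuZero.train()` in muzero.py wires these actors together with Ray. It is a condensed outline of the code in this commit, with argument lists abbreviated; `config`, `Game` and `game_name` come from a game file such as games/cartpole.py.

```python
# Condensed outline of MuZero.train() (see muzero.py later in this commit).
import ray
import models, replay_buffer, self_play, shared_storage, trainer

ray.init()
weights = models.MuZeroNetwork(
    config.observation_shape, len(config.action_space),
    config.encoding_size, config.hidden_size,
).get_weights()

storage = shared_storage.SharedStorage.remote(weights, 0, game_name, config)
buffer = replay_buffer.ReplayBuffer.remote(config)
training_worker = trainer.Trainer.remote(weights, 0, config, "cuda")
self_play_workers = [
    self_play.SelfPlay.remote(weights, Game(config.seed + seed), config, "cpu")
    for seed in range(config.num_actors)
]

# Self-play workers fill the replay buffer; the trainer consumes batches from it
# and periodically pushes fresh weights back to the shared storage.
for worker in self_play_workers:
    worker.continuous_self_play.remote(storage, buffer)
training_worker.continuous_update_weights.remote(buffer, storage)
```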
## Games already implemented with pretrained network available

* Lunar Lander
* Cartpole
## Getting started

### Installation

```bash
pip install -r requirements.txt
```

### Training

Edit the end of muzero.py:
```python
muzero = Muzero("cartpole")
muzero.train()
```
Then run:
```bash
python muzero.py
```
To visualize the training results, run in a new bash:
```bash
tensorboard --logdir ./
```
### Testing

Edit the end of muzero.py:
```python
muzero = Muzero("cartpole")
muzero.load_model()
muzero.test()
```
Then run:
```bash
python muzero.py
```
## Coming soon

* [ ] Atari mode with residual network
* [ ] Live test policy & value tracking
* [ ] [Open spiel](https://github.com/deepmind/open_spiel) integration
* [ ] Checkers game
* [ ] TensorFlow mode
......
import gym import gym
import numpy import numpy
import torch
class MuZeroConfig: class MuZeroConfig:
...@@ -31,19 +30,17 @@ class MuZeroConfig: ...@@ -31,19 +30,17 @@ class MuZeroConfig:
self.max_known_bound = None self.max_known_bound = None
### Network ### Network
self.encoding_size = 32 self.encoding_size = 64
self.hidden_size = 64 self.hidden_size = 32
# Training # Training
self.results_path = "./pretrained" # Path to store the model weights self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available self.training_steps = 1000 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 400 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.test_episodes = 2 # Number of games played to evaluate the network
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of steps in the future to take into account when calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization self.weight_decay = 1e-4 # L2 weights regularization
...@@ -54,7 +51,7 @@ class MuZeroConfig: ...@@ -54,7 +51,7 @@ class MuZeroConfig:
self.lr_decay_rate = 0.1 self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500 self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps): def visit_softmax_temperature_fn(self, trained_steps):
""" """
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses. Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen. The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.
......
...@@ -13,7 +13,7 @@ class MuZeroConfig: ...@@ -13,7 +13,7 @@ class MuZeroConfig:
### Self-Play ### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 100 # Maximum number of moves if game is not finished before self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of futur moves self-simulated self.num_simulations = 50 # Number of futur moves self-simulated
self.discount = 0.997 # Chronological discount of the reward self.discount = 0.997 # Chronological discount of the reward
...@@ -31,30 +31,28 @@ class MuZeroConfig: ...@@ -31,30 +31,28 @@ class MuZeroConfig:
self.max_known_bound = None self.max_known_bound = None
### Network ### Network
self.encoding_size = 16 self.encoding_size = 64
self.hidden_size = 8 self.hidden_size = 32
# Training # Training
self.results_path = "./pretrained" # Path to store the model weights self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available self.training_steps = 20000 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 700 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 50 # Number of game moves to keep for every batch element self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.test_episodes = 2 # Number of games played to evaluate the network
self.checkpoint_interval = 20 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 100 # Number of steps in the future to take into account when calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9 self.momentum = 0.9
# Exponential learning rate schedule # Exponential learning rate schedule
self.lr_init = 0.005 # Initial learning rate self.lr_init = 0.005 # Initial learning rate
self.lr_decay_rate = 0.01 self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500 self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps): def visit_softmax_temperature_fn(self, trained_steps):
""" """
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses. Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen. The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.
...@@ -62,14 +60,12 @@ class MuZeroConfig: ...@@ -62,14 +60,12 @@ class MuZeroConfig:
Returns: Returns:
Positive float. Positive float.
""" """
if trained_steps < 0.25 * self.training_steps: if trained_steps < 0.5 * self.training_steps:
return 1000 return 1.0
elif trained_steps < 0.5 * self.training_steps:
return 1
elif trained_steps < 0.75 * self.training_steps: elif trained_steps < 0.75 * self.training_steps:
return 0.5 return 0.5
else: else:
return 0.1 return 0.25
class Game: class Game:
......
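Besides the MuZeroConfig shown above, each game file defines a Game class wrapping the environment. The interface that the rest of the code expects is small; the skeleton below is inferred from the calls made in self_play.py and muzero.py (the gym CartPole environment and the observation flattening are illustrative, not the exact file from this commit):

```python
import gym
import numpy

class Game:
    """Illustrative wrapper exposing the interface used by self_play.py."""

    def __init__(self, seed=None):
        self.env = gym.make("CartPole-v1")
        if seed is not None:
            self.env.seed(seed)

    def reset(self):
        # Return the initial observation as a flat numpy array
        return numpy.array(self.env.reset()).flatten()

    def step(self, action):
        # Return (observation, reward, done) for the chosen action
        observation, reward, done, _ = self.env.step(action)
        return numpy.array(observation).flatten(), reward, done

    def render(self):
        self.env.render()

    def close(self):
        self.env.close()
```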
import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# TODO: unified residual network
class MuZeroNetwork(torch.nn.Module):
def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
super().__init__()
self.action_space_size = action_space_size
self.representation_network = FullyConnectedNetwork(
observation_size, [], encoding_size
)
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], 1
)
self.prediction_policy_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_size, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, encoded_state):
policy_logit = self.prediction_policy_network(encoded_state)
value = self.prediction_value_network(encoded_state)
return policy_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, encoded_state, action):
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logit, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
policy_logit,
encoded_state,
)
def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
def get_weights(self):
return {key: value.cpu() for key, value in self.state_dict().items()}
def set_weights(self, weights):
self.load_state_dict(weights)
\ No newline at end of file
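To see how these pieces fit together, here is a tiny usage sketch of the network just defined (the sizes are made up for illustration; this snippet is not part of the commit):

```python
import torch
from models import MuZeroNetwork

# Illustrative sizes only: 4 observation features and 2 actions, as in CartPole.
net = MuZeroNetwork(observation_size=4, action_space_size=2, encoding_size=32, hidden_size=64)

observation = torch.rand(1, 4)          # a batch containing one observation
value, reward, policy_logits, state = net.initial_inference(observation)

action = torch.tensor([[1]]).float()    # one action per batch element, shape (batch, 1)
value, reward, policy_logits, state = net.recurrent_inference(state, action)
```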
import copy
import datetime
import importlib import importlib
import os import os
import time import time
...@@ -5,9 +7,13 @@ import time ...@@ -5,9 +7,13 @@ import time
import numpy import numpy
import ray import ray
import torch import torch
from torch.utils.tensorboard import SummaryWriter
import network import models
import replay_buffer
import self_play import self_play
import shared_storage
import trainer
class MuZero: class MuZero:
...@@ -44,144 +50,123 @@ class MuZero: ...@@ -44,144 +50,123 @@ class MuZero:
numpy.random.seed(self.config.seed) numpy.random.seed(self.config.seed)
torch.manual_seed(self.config.seed) torch.manual_seed(self.config.seed)
self.best_model = network.Network( # Used to initialize components when continuing a former training
self.muzero_weights = models.MuZeroNetwork(
self.config.observation_shape, self.config.observation_shape,
len(self.config.action_space), len(self.config.action_space),
self.config.encoding_size, self.config.encoding_size,
self.config.hidden_size, self.config.hidden_size,
) ).get_weights()
self.training_steps = 0
def train(self): def train(self):
# Initialize and launch components that work simultaneously
ray.init() ray.init()
model = self.best_model writer = SummaryWriter(
model.train() os.path.join(self.config.results_path, self.game_name + "_summary")
storage = network.SharedStorage.remote(model)
replay_buffer = self_play.ReplayBuffer.remote(self.config)
for seed in range(self.config.num_actors):
self_play.run_selfplay.remote(
self.Game,
self.config,
storage,
replay_buffer,
model,
self.config.seed + seed,
) )
# Initialize network for training # Initialize workers
model = model.to(self.config.device) training_worker = trainer.Trainer.remote(
optimizer = torch.optim.SGD( copy.deepcopy(self.muzero_weights),
model.parameters(), self.training_steps,
lr=self.config.lr_init, self.config,
momentum=self.config.momentum, # Train on GPU if available
weight_decay=self.config.weight_decay, "cuda" if torch.cuda.is_available() else "cpu",
) )
shared_storage_worker = shared_storage.SharedStorage.remote(
# Wait for replay buffer to be non-empty copy.deepcopy(self.muzero_weights),
while ray.get(replay_buffer.length.remote()) == 0: self.training_steps,
time.sleep(0.1) self.game_name,
self.config,
# Training loop
best_test_rewards = None
for training_step in range(self.config.training_steps):
model.train()
storage.set_training_step.remote(training_step)
# Make the model available for self-play
if training_step % self.config.checkpoint_interval == 0:
storage.set_weights.remote(model.state_dict())
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
training_step / self.config.lr_decay_steps
) )
for param_group in optimizer.param_groups: replay_buffer_worker = replay_buffer.ReplayBuffer.remote(self.config)
param_group["lr"] = lr self_play_workers = [
self_play.SelfPlay.remote(
# Train on a batch. copy.deepcopy(self.muzero_weights),
batch = ray.get( self.Game(self.config.seed + seed),
replay_buffer.sample_batch.remote( self.config,
self.config.num_unroll_steps, self.config.td_steps "cpu",
) )
for seed in range(self.config.num_actors)
]
test_worker = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
) )
loss = network.update_weights(optimizer, model, batch, self.config)
# Test the current model and save it on disk if it is the best # Launch workers
if training_step % self.config.test_interval == 0: [
total_test_rewards = self.test(model=model, render=False) self_play_worker.continuous_self_play.remote(
if best_test_rewards is None or sum(total_test_rewards) >= sum( shared_storage_worker, replay_buffer_worker
best_test_rewards )
): for self_play_worker in self_play_workers
self.best_model = model ]
best_test_rewards = total_test_rewards test_worker.continuous_self_play.remote(shared_storage_worker, None, True)
self.save_model() training_worker.continuous_update_weights.remote(
replay_buffer_worker, shared_storage_worker
)
# Loop for monitoring in real time the workers
print( print(
"Training step: {}\nBuffer Size: {}\nLearning rate: {}\nLoss: {}\nLast test score: {}\nBest sest score: {}\n".format( "Run tensorboard --logdir ./ and go to http://localhost:6006/ to track the training performance"
training_step,
ray.get(replay_buffer.length.remote()),
lr,
loss,
str(total_test_rewards),
str(best_test_rewards),
) )
counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote())
while infos["training_step"] < self.config.training_steps:
# Get and save real time performance
infos = ray.get(shared_storage_worker.get_infos.remote())
writer.add_scalar(
"1.Total reward/Total reward", infos["total_reward"], counter
) )
writer.add_scalar(
# Finally, save the latest network in the shared storage and end the self-play "2.Workers/Self played games",
storage.set_weights.remote(model.state_dict()) ray.get(replay_buffer_worker.get_self_play_count.remote()),
counter,
)
writer.add_scalar(
"2.Workers/Training steps", infos["training_step"], counter
)
writer.add_scalar("3.Loss/1.Total loss", infos["total_loss"], counter)
writer.add_scalar("3.Loss details/Value loss", infos["value_loss"], counter)
writer.add_scalar(
"3.Loss details/Reward loss", infos["reward_loss"], counter
)
writer.add_scalar(
"3.Loss details/Policy loss", infos["policy_loss"], counter
)
counter += 1
time.sleep(1)
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown() ray.shutdown()
def test(self, model=None, render=True): def test(self, render=True):
if not model: """
model = self.best_model Test the model in a dedicated thread.
"""
model.to(self.config.device) ray.init()
self_play_workers = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
)
test_rewards = [] test_rewards = []
game = self.Game()
model.eval()
with torch.no_grad(): with torch.no_grad():
for _ in range(self.config.test_episodes): for _ in range(self.config.test_episodes):
observation = game.reset() history = ray.get(self_play_workers.self_play.remote(0, render))
done = False test_rewards.append(sum(history.rewards))
total_reward = 0 ray.shutdown()
while not done:
if render:
game.render()
root = self_play.MCTS(self.config).run(model, observation, False)
action = self_play.select_action(root, temperature=0)
observation, reward, done = game.step(action)
total_reward += reward
test_rewards.append(total_reward)
return test_rewards return test_rewards
def save_model(self, model=None, path=None): def load_model(self, path=None, training_step=0):
if not model: # TODO: why pretrained model is degradated during the new train
model = self.best_model
if not path: if not path:
path = os.path.join(self.config.results_path, self.game_name) path = os.path.join(self.config.results_path, self.game_name)
torch.save(model.state_dict(), path)
def load_model(self, path=None):
if not path:
path = os.path.join(self.config.results_path, self.game_name)
self.best_model = network.Network(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
try: try:
self.best_model.load_state_dict(torch.load(path)) self.muzero_weights = torch.load(path)
self.training_step = training_step
except FileNotFoundError: except FileNotFoundError:
print("There is no model saved in {}.".format(path)) print("There is no model saved in {}.".format(path))
if __name__ == "__main__": if __name__ == "__main__":
# Load the game and the parameters from ./games/file_name.py
muzero = MuZero("cartpole") muzero = MuZero("cartpole")
muzero.load_model()
muzero.train() muzero.train()
# muzero.load_model()
muzero.test() muzero.test()
import ray
import torch
class Network(torch.nn.Module):
def __init__(self, input_size, action_space_n, encoding_size, hidden_size):
super().__init__()
self.action_space_n = action_space_n
self.representation_network = FullyConnectedNetwork(
input_size, [], encoding_size
)
self.dynamics_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], 1
)
self.prediction_actor_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_n, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, state):
actor_logit = self.prediction_actor_network(state)
value = self.prediction_value_network(state)
return actor_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, state, action):
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_n))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((state, action_one_hot), dim=1)
next_state = self.dynamics_state_network(x)
reward = self.dynamics_reward_network(x)
return next_state, reward
def initial_inference(self, observation):
state = self.representation(observation)
actor_logit, value = self.prediction(state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
actor_logit,
state,
)
def recurrent_inference(self, hidden_state, action):
state, reward = self.dynamics(hidden_state, action)
actor_logit, value = self.prediction(state)
return value, reward, actor_logit, state
def update_weights(optimizer, model, batch, config):
observation_batch, action_batch, target_reward, target_value, target_policy = batch
observation_batch = torch.tensor(observation_batch).float().to(config.device)
action_batch = torch.tensor(action_batch).float().to(config.device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(config.device)
target_reward = torch.tensor(target_reward).float().to(config.device)
target_policy = torch.tensor(target_policy).float().to(config.device)
value, reward, policy_logits, hidden_state = model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(config.num_unroll_steps):
value, reward, policy_logits, hidden_state = model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
loss = 0
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
loss += loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = loss.mean() / config.num_unroll_steps
# Optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss + reward_loss + policy_loss
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
@ray.remote
class SharedStorage:
def __init__(self, model):
self.training_step = 0
self.model = model
def get_weights(self):
return self.model.state_dict()
def set_weights(self, weights):
return self.model.load_state_dict(weights)
def set_training_step(self, training_step):
self.training_step = training_step
def get_training_step(self):
return self.training_step
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Google colab imports\n",
"!pip install -r requirements.txt\n",
"!pip uninstall -y pyarrow\n",
"# If you have an import issue with ray, restart the environment (execution menu)\n",
"\n",
"# You must have the repository imported along with your notebook. \n",
"# For google colab, click on \">\" buttton (left) and import files (muzero.py, self_play.py, ...).\n",
"import muzero as mz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Train on cartpole game\n",
"muzero = mz.MuZero(\"cartpole\")\n",
"muzero.load_model()\n",
"muzero.train()\n",
"muzero.test()"
]
}
]
}
\ No newline at end of file
import numpy
import ray
@ray.remote
class ReplayBuffer:
"""
Class which runs in a dedicated thread to store played games and generate batches.
"""
def __init__(self, config):
self.config = config
self.buffer = []
self.self_play_count = 0
def save_game(self, game_history):
if len(self.buffer) > self.config.window_size:
self.buffer.pop(0)
self.buffer.append(game_history)
self.self_play_count += 1
def get_self_play_count(self):
return self.self_play_count
def get_batch(self):
observation_batch, action_batch, reward_batch, value_batch, policy_batch = (
[],
[],
[],
[],
[],
)
for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer)
game_pos = sample_position(game_history)
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps
]
# Repeat the previous action to make "actions" of length "num_unroll_steps"
actions.extend(
[
actions[-1]
for _ in range(self.config.num_unroll_steps - len(actions) + 1)
]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = make_target(
game_history,
game_pos,
self.config.num_unroll_steps,
self.config.td_steps,
)
value_batch.append(value)
reward_batch.append(reward)
policy_batch.append(policy)
return observation_batch, action_batch, value_batch, reward_batch, policy_batch
def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with a probability linked to the difference between the real and predicted value (see paper appendix Training)
return numpy.random.choice(buffer)
def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = (
game_history.root_values[bootstrap_index]
* game_history.discount ** td_steps
)
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * game_history.discount ** i
if current_index < len(game_history.root_values):
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game_history.child_visits[0])
for _ in range(len(game_history.child_visits[0]))
]
)
return target_values, target_rewards, target_policies
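Spelled out, the value target that make_target computes for a position $t$ with $k = \texttt{td\_steps}$ is the usual n-step return (this formula simply restates the code above):

$$
z_t = \sum_{i=0}^{k-1} \gamma^{i}\, r_{t+i} \;+\; \gamma^{k}\, \nu_{t+k},
$$

where $\gamma$ is game_history.discount, $r_{t+i}$ are the stored rewards and $\nu_{t+k}$ is the root value of the search tree td_steps into the future (taken as 0 when the game has already ended by then).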
...@@ -4,43 +4,106 @@ import numpy ...@@ -4,43 +4,106 @@ import numpy
import ray import ray
import torch import torch
import models
@ray.remote # (num_gpus=1) # Uncoomment num_gpus and model.to(config.device) to self-play on GPU
def run_selfplay(Game, config, shared_storage, replay_buffer, model, seed): @ray.remote
class SelfPlay:
""" """
Function which runs simultaneously in multiple threads, continuously playing games and saving them to the replay buffer. Class which runs in a dedicated thread to play games and save them to the replay buffer.
""" """
# model.to(config.device) # Uncoomment this line and num_gpus from the decorator to self-play on GPU def __init__(self, initial_weights, game, config, device):
self.config = config
self.game = game
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
with torch.no_grad(): with torch.no_grad():
while True: while True:
# Initialize a self-play self.model.set_weights(ray.get(shared_storage.get_weights.remote()))
model.load_state_dict(ray.get(shared_storage.get_weights.remote()))
game = Game(seed) # Take the best action (no exploration) in test mode
done = False temperature = (
game_history = GameHistory(config.discount) 0
if test_mode
else self.config.visit_softmax_temperature_fn(
trained_steps=ray.get(shared_storage.get_infos.remote())[
"training_step"
]
)
)
# Self-play with actions based on the Monte Carlo tree search at each moves game_history = self.self_play(temperature, False)
observation = game.reset()
game_history.observation_history.append(observation)
while not done and len(game_history.history) < config.max_moves:
root = MCTS(config).run(model, observation, True)
temperature = config.visit_softmax_temperature_fn( # Save to the shared storage
num_moves=len(game_history.history), if test_mode:
trained_steps=ray.get(shared_storage.get_training_step.remote()), shared_storage.set_infos.remote(
"total_reward", sum(game_history.rewards)
) )
if not test_mode:
replay_buffer.save_game.remote(game_history)
def self_play(self, temperature, render):
"""
Play one game with actions based on the Monte Carlo tree search at each move.
"""
game_history = GameHistory(self.config.discount)
observation = self.game.reset()
game_history.observation_history.append(observation)
done = False
while not done and len(game_history.history) < self.config.max_moves:
root = MCTS(self.config).run(self.model, observation, True)
action = select_action(root, temperature) action = select_action(root, temperature)
observation, reward, done = game.step(action) observation, reward, done = self.game.step(action)
if render:
self.game.render()
print("Press enter to step")
game_history.observation_history.append(observation) game_history.observation_history.append(observation)
game_history.rewards.append(reward) game_history.rewards.append(reward)
game_history.history.append(action) game_history.history.append(action)
game_history.store_search_statistics(root, config.action_space) game_history.store_search_statistics(root, self.config.action_space)
game.close() self.game.close()
# Save the game history return game_history
replay_buffer.save_game.remote(game_history)
def select_action(node, temperature):
"""
Select an action according to the visit count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
).T
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
else:
# See paper Data Generation appendix
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action_pos = numpy.random.choice(
len(visit_counts[1]), p=visit_count_distribution
)
if temperature == float("inf"):
action_pos = numpy.random.choice(len(visit_counts[1]))
return visit_counts[1][action_pos]
# Game independent
...@@ -62,7 +125,10 @@ class MCTS: ...@@ -62,7 +125,10 @@ class MCTS:
""" """
root = Node(0) root = Node(0)
observation = ( observation = (
torch.from_numpy(observation).to(self.config.device).float().unsqueeze(0) torch.from_numpy(observation)
.float()
.unsqueeze(0)
.to(next(model.parameters()).device)
) )
_, expected_reward, policy_logits, hidden_state = model.initial_inference( _, expected_reward, policy_logits, hidden_state = model.initial_inference(
observation observation
...@@ -76,9 +142,7 @@ class MCTS: ...@@ -76,9 +142,7 @@ class MCTS:
exploration_fraction=self.config.root_exploration_fraction, exploration_fraction=self.config.root_exploration_fraction,
) )
min_max_stats = MinMaxStats( min_max_stats = MinMaxStats()
self.config.min_known_bound, self.config.max_known_bound
)
for _ in range(self.config.num_simulations): for _ in range(self.config.num_simulations):
node = root node = root
...@@ -141,32 +205,6 @@ class MCTS: ...@@ -141,32 +205,6 @@ class MCTS:
value = node.reward + self.config.discount * value value = node.reward + self.config.discount * value
def select_action(node, temperature, random=False):
"""
Select action according to the vivist count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
).T
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
else:
# See paper Data Generation appendix
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action_pos = numpy.random.choice(
len(visit_counts[1]), p=visit_count_distribution
)
if random:
action_pos = numpy.random.choice(len(visit_counts[1]))
return visit_counts[1][action_pos]
class Node: class Node:
def __init__(self, prior): def __init__(self, prior):
self.visit_count = 0 self.visit_count = 0
...@@ -231,105 +269,14 @@ class GameHistory: ...@@ -231,105 +269,14 @@ class GameHistory:
self.root_values.append(root.value()) self.root_values.append(root.value())
@ray.remote
class ReplayBuffer:
# Store list of game history and generate batch
def __init__(self, config):
self.window_size = config.window_size
self.batch_size = config.batch_size
self.buffer = []
def save_game(self, game_history):
if len(self.buffer) > self.window_size:
self.buffer.pop(0)
self.buffer.append(game_history)
def sample_batch(self, num_unroll_steps, td_steps):
observation_batch, action_batch, reward_batch, value_batch, policy_batch = (
[],
[],
[],
[],
[],
)
for _ in range(self.batch_size):
game_history = self.sample_game()
game_pos = self.sample_position(game_history)
actions = game_history.history[game_pos : game_pos + num_unroll_steps]
# Repeat precedent action to make "actions" of length "num_unroll_steps"
actions.extend(
[actions[-1] for _ in range(num_unroll_steps - len(actions) + 1)]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = self.make_target(
game_history, game_pos, num_unroll_steps, td_steps
)
reward_batch.append(reward)
value_batch.append(value)
policy_batch.append(policy)
return observation_batch, action_batch, reward_batch, value_batch, policy_batch
def sample_game(self):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability link to the highest difference between real and predicted value (see paper appendix Training)
return self.buffer[numpy.random.choice(range(len(self.buffer)))]
def sample_position(self, game):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
return numpy.random.choice(range(len(game.rewards)))
def make_target(self, game, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game.root_values):
value = game.root_values[bootstrap_index] * game.discount ** td_steps
else:
value = 0
for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
value += reward * game.discount ** i
if current_index < len(game.root_values):
target_values.append(value)
target_rewards.append(game.rewards[current_index])
target_policies.append(game.child_visits[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game.child_visits[0])
for _ in range(len(game.child_visits[0]))
]
)
return target_values, target_rewards, target_policies
def length(self):
return len(self.buffer)
class MinMaxStats: class MinMaxStats:
""" """
A class that holds the min-max values of the tree. A class that holds the min-max values of the tree.
""" """
def __init__(self, min_value_bound, max_value_bound): def __init__(self):
self.maximum = min_value_bound if min_value_bound else -float("inf") self.maximum = -float("inf")
self.minimum = max_value_bound if max_value_bound else float("inf") self.minimum = float("inf")
def update(self, value): def update(self, value):
self.maximum = max(self.maximum, value) self.maximum = max(self.maximum, value)
......
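The temperature returned by visit_softmax_temperature_fn (see the game configs above) controls how greedy select_action is. A quick hand-computed illustration with made-up visit counts:

```python
import numpy

# visit_counts ** (1 / T), then normalized, as in select_action above
visit_counts = numpy.array([40, 10], dtype=float)
for T in (2.0, 1.0, 0.5):
    d = visit_counts ** (1 / T)
    print(T, d / d.sum())  # 2.0 -> [0.67 0.33], 1.0 -> [0.8 0.2], 0.5 -> [0.94 0.06]
# T == 0 is handled separately in select_action: the most visited action is always chosen.
```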
import ray
import torch
import os
@ray.remote
class SharedStorage:
"""
Class which runs in a dedicated thread to store the network weights and some information.
"""
def __init__(self, weights, training_step, game_name, config):
self.config = config
self.game_name = game_name
self.weights = weights
self.infos = {'training_step': training_step,
'total_reward': 0,
'total_loss': 0,
'value_loss': 0,
'reward_loss': 0,
'policy_loss': 0}
def get_weights(self):
return self.weights
def set_weights(self, weights, path=None):
self.weights = weights
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(self.weights, path)
def get_infos(self):
return self.infos
def set_infos(self, key, value):
self.infos[key] = value
import time
import numpy
import ray
import torch
import models
@ray.remote(num_gpus=1)
class Trainer:
"""
Class which runs in a dedicated thread to train a neural network and save it in the shared storage.
"""
def __init__(self, initial_weights, initial_training_step, config, device):
self.config = config
self.training_step = initial_training_step
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
self.model.train()
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.config.lr_init,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
)
def continuous_update_weights(self, replay_buffer, shared_storage_worker):
# Wait for the replay buffer to be filled
while ray.get(replay_buffer.get_self_play_count.remote()) < 1:
time.sleep(0.1)
# Training loop
while True:
batch = ray.get(replay_buffer.get_batch.remote())
total_loss, value_loss, reward_loss, policy_loss = self.update_weights(
batch
)
# Save to the shared storage
if self.training_step % self.config.checkpoint_interval == 0:
shared_storage_worker.set_weights.remote(self.model.get_weights())
shared_storage_worker.set_infos.remote("training_step", self.training_step)
shared_storage_worker.set_infos.remote("total_loss", total_loss)
shared_storage_worker.set_infos.remote("value_loss", value_loss)
shared_storage_worker.set_infos.remote("reward_loss", reward_loss)
shared_storage_worker.set_infos.remote("policy_loss", policy_loss)
def update_weights(self, batch):
"""
Perform one training step.
"""
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
(
observation_batch,
action_batch,
target_value,
target_reward,
target_policy,
) = batch
device = next(self.model.parameters()).device
observation_batch = torch.tensor(observation_batch).float().to(device)
action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(device)
target_reward = torch.tensor(target_reward).float().to(device)
target_policy = torch.tensor(target_policy).float().to(device)
value, reward, policy_logits, hidden_state = self.model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(self.config.num_unroll_steps):
value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
# Compute losses
value_loss, reward_loss, policy_loss = (0, 0, 0)
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
(
current_value_loss,
current_reward_loss,
current_policy_loss,
) = loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
value_loss += current_value_loss
reward_loss += current_reward_loss
policy_loss += current_policy_loss
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = (
value_loss + reward_loss + policy_loss
).mean() / self.config.num_unroll_steps
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.training_step += 1
return (
loss.item(),
value_loss.mean().item(),
reward_loss.mean().item(),
policy_loss.mean().item(),
)
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss, reward_loss, policy_loss
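The TODO above refers to the paper's trick of treating value and reward as categorical distributions over a discrete support and training them with cross-entropy instead of MSE. A possible shape of that target encoding is sketched below; the support size, function names and two-hot encoding are illustrative and not part of this commit, and the network heads would also have to output 2 * support_size + 1 logits (the paper additionally applies an invertible scaling to the scalar before encoding it):

```python
import torch

def scalar_to_two_hot(x, support_size=10):
    """Encode scalar targets of shape (batch,) as a 'two-hot' distribution
    over the support {-support_size, ..., support_size}. Illustrative sketch."""
    x = torch.clamp(x, -support_size, support_size)
    low = torch.floor(x)
    high = torch.ceil(x)                 # equal to low when x is an integer
    p_high = x - low
    p_low = 1.0 - p_high
    target = torch.zeros(x.shape[0], 2 * support_size + 1)
    rows = torch.arange(x.shape[0])
    target[rows, (low + support_size).long()] += p_low
    target[rows, (high + support_size).long()] += p_high
    return target

def categorical_value_loss(value_logits, target_value, support_size=10):
    # Cross-entropy between predicted logits and the two-hot encoded scalar target.
    target = scalar_to_two_hot(target_value, support_size).to(value_logits.device)
    return -(torch.log_softmax(value_logits, dim=1) * target).sum(1)
```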