Commit 06564837 by Werner Duvaud

Refactored

parent d8388353
......@@ -7,17 +7,27 @@
# MuZero General
A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable for every games or reinforcement learning environnements (like [gym](https://github.com/openai/gym)). You only need to edit the game file with the parameters and the game class. Please refer to the documentation and the tutorial.
It is designed to be easily adaptable to any game or reinforcement learning environment (like [gym](https://github.com/openai/gym)). You only need to edit the game file with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games whithout knowing the rules. It only know actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
MuZero is a model-based reinforcement learning algorithm, the successor of AlphaZero. It learns to master games without knowing the rules: it only knows the available actions, and from them it learns to play and master the game. It is more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) to run the different components simultaneously. There is complete GPU support.
There are four "actors": classes that run simultaneously, each in a dedicated thread.
The shared storage holds the latest neural network weights, the self-play workers use those weights to generate games and store them in the replay buffer, and the trainer uses those games to update the network and writes the new weights back to the shared storage. The circle is complete.
These components are launched and managed from the MuZero class in muzero.py, and the structure of the neural network is defined in models.py.
Performance is tracked and displayed in real time in TensorBoard.
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained/pretrained/cartpole_training_summary.png)
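For reference, here is a minimal sketch of how these components could be wired together with Ray, mirroring what the MuZero class does in this commit. The direct `games.cartpole` import and the hard-coded arguments are simplifications for illustration; see muzero.py for the actual logic.
```python
import ray
import torch

import models
import replay_buffer
import self_play
import shared_storage
import trainer
from games.cartpole import MuZeroConfig, Game  # assumes games/ is importable as a package

config = MuZeroConfig()
initial_weights = models.MuZeroNetwork(
    config.observation_shape,
    len(config.action_space),
    config.encoding_size,
    config.hidden_size,
).get_weights()

ray.init()

# The shared storage holds the latest weights; the replay buffer stores finished games.
storage = shared_storage.SharedStorage.remote(initial_weights, 0, "cartpole", config)
buffer = replay_buffer.ReplayBuffer.remote(config)

# Self-play workers continuously pull the latest weights, play games and push them to the buffer.
workers = [
    self_play.SelfPlay.remote(initial_weights, Game(config.seed + seed), config, "cpu")
    for seed in range(config.num_actors)
]
for worker in workers:
    worker.continuous_self_play.remote(storage, buffer)

# The trainer samples batches from the buffer and writes updated weights back to the shared storage.
# Note: Trainer is declared with @ray.remote(num_gpus=1), so it expects a GPU to be available.
training_worker = trainer.Trainer.remote(
    initial_weights, 0, config, "cuda" if torch.cuda.is_available() else "cpu"
)
training_worker.continuous_update_weights.remote(buffer, storage)
```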
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for self-playing on multiple threads. A synchronous mode (easier for debug) will be released. There is a complete GPU support.
The code has three parts, muzero.py with the entry class, self-play.py with the replay-buffer and the MCTS classes, and network.py with the neural networks and the shared storage classes.
## Games already implemented with pretrained networks available
* Lunar Lander
* Cartpole
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/tree/master/games/lunarlander_training_preview.png)
## Getting started
### Installation
```bash
......@@ -26,32 +36,35 @@ pip install -r requirements.txt
```
### Training
Edit the end of muzero.py :
Edit the end of muzero.py:
```python
muzero = MuZero("cartpole")
muzero.train()
```
Then run :
Then run:
```bash
python muzero.py
```
To visualize the training results, run in a new bash:
```bash
tensorboard --logdir ./
```
### Testing
Edit the end of muzero.py :
Edit the end of muzero.py:
```python
muzero = MuZero("cartpole")
muzero.load_model()
muzero.test()
```
Then run :
Then run:
```bash
python muzero.py
```
## Coming soon
* [ ] Convolutionnal / Atari mode
* [ ] Performance tracking
* [ ] Synchronous mode
* [ ] Atari mode with residual network
* [ ] Live test policy & value tracking
* [ ] [Open spiel](https://github.com/deepmind/open_spiel) integration
* [ ] Checkers game
* [ ] TensorFlow mode
......
import gym
import numpy
import torch
class MuZeroConfig:
......@@ -31,20 +30,18 @@ class MuZeroConfig:
self.max_known_bound = None
### Network
self.encoding_size = 32
self.hidden_size = 64
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available
self.training_steps = 400 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 1000 # Total number of training steps (i.e. weight updates, one per batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.test_episodes = 2 # Number of games played to evaluate the network
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in memory (in the replay buffer)
self.td_steps = 10 # Number of steps in the futur to take into account for calculating the target value
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of steps in the future to take into account for calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
......@@ -54,7 +51,7 @@ class MuZeroConfig:
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps):
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (i.e. the one with the highest visit count) is chosen.
......
......@@ -13,7 +13,7 @@ class MuZeroConfig:
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 100 # Maximum number of moves if game is not finished before
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
......@@ -31,30 +31,28 @@ class MuZeroConfig:
self.max_known_bound = None
### Network
self.encoding_size = 16
self.hidden_size = 8
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available
self.training_steps = 700 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 20000 # Total number of training steps (i.e. weight updates, one per batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 50 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_episodes = 2 # Number of games played to evaluate the network
self.checkpoint_interval = 20 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in memory (in the replay buffer)
self.td_steps = 50 # Number of steps in the futur to take into account for calculating the target value
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 100 # Number of steps in the future to take into account for calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.005 # Initial learning rate
self.lr_decay_rate = 0.01
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps):
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (i.e. the one with the highest visit count) is chosen.
......@@ -62,14 +60,12 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
return 1000
elif trained_steps < 0.5 * self.training_steps:
return 1
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
return 0.1
return 0.25
class Game:
......
import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# TODO: unified residual network
class MuZeroNetwork(torch.nn.Module):
def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
super().__init__()
self.action_space_size = action_space_size
self.representation_network = FullyConnectedNetwork(
observation_size, [], encoding_size
)
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], 1
)
self.prediction_policy_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_size, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, encoded_state):
policy_logit = self.prediction_policy_network(encoded_state)
value = self.prediction_value_network(encoded_state)
return policy_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, encoded_state, action):
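# Encode the action as a one-hot vector and concatenate it with the encoded state
# before feeding it to the dynamics networks.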
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logit, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
policy_logit,
encoded_state,
)
def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
def get_weights(self):
return {key: value.cpu() for key, value in self.state_dict().items()}
def set_weights(self, weights):
self.load_state_dict(weights)
\ No newline at end of file
import copy
import datetime
import importlib
import os
import time
......@@ -5,9 +7,13 @@ import time
import numpy
import ray
import torch
from torch.utils.tensorboard import SummaryWriter
import network
import models
import replay_buffer
import self_play
import shared_storage
import trainer
class MuZero:
......@@ -44,144 +50,123 @@ class MuZero:
numpy.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
self.best_model = network.Network(
# Used to initialize components when resuming a previous training run
self.muzero_weights = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
).get_weights()
self.training_steps = 0
def train(self):
# Initialize and launch components that work simultaneously
ray.init()
model = self.best_model
model.train()
storage = network.SharedStorage.remote(model)
replay_buffer = self_play.ReplayBuffer.remote(self.config)
for seed in range(self.config.num_actors):
self_play.run_selfplay.remote(
self.Game,
writer = SummaryWriter(
os.path.join(self.config.results_path, self.game_name + "_summary")
)
# Initialize workers
training_worker = trainer.Trainer.remote(
copy.deepcopy(self.muzero_weights),
self.training_steps,
self.config,
# Train on GPU if available
"cuda" if torch.cuda.is_available() else "cpu",
)
shared_storage_worker = shared_storage.SharedStorage.remote(
copy.deepcopy(self.muzero_weights),
self.training_steps,
self.game_name,
self.config,
)
replay_buffer_worker = replay_buffer.ReplayBuffer.remote(self.config)
self_play_workers = [
self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights),
self.Game(self.config.seed + seed),
self.config,
storage,
replay_buffer,
model,
self.config.seed + seed,
"cpu",
)
# Initialize network for training
model = model.to(self.config.device)
optimizer = torch.optim.SGD(
model.parameters(),
lr=self.config.lr_init,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
for seed in range(self.config.num_actors)
]
test_worker = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
)
# Wait for replay buffer to be non-empty
while ray.get(replay_buffer.length.remote()) == 0:
time.sleep(0.1)
# Training loop
best_test_rewards = None
for training_step in range(self.config.training_steps):
model.train()
storage.set_training_step.remote(training_step)
# Make the model available for self-play
if training_step % self.config.checkpoint_interval == 0:
storage.set_weights.remote(model.state_dict())
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
training_step / self.config.lr_decay_steps
# Launch workers
[
self_play_worker.continuous_self_play.remote(
shared_storage_worker, replay_buffer_worker
)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
for self_play_worker in self_play_workers
]
test_worker.continuous_self_play.remote(shared_storage_worker, None, True)
training_worker.continuous_update_weights.remote(
replay_buffer_worker, shared_storage_worker
)
# Train on a batch.
batch = ray.get(
replay_buffer.sample_batch.remote(
self.config.num_unroll_steps, self.config.td_steps
)
# Loop to monitor the workers in real time
print(
"Run tensorboard --logdir ./ and go to http://localhost:6006/ to track the training performance"
)
counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote())
while infos["training_step"] < self.config.training_steps:
# Get and save real time performance
infos = ray.get(shared_storage_worker.get_infos.remote())
writer.add_scalar(
"1.Total reward/Total reward", infos["total_reward"], counter
)
loss = network.update_weights(optimizer, model, batch, self.config)
# Test the current model and save it on disk if it is the best
if training_step % self.config.test_interval == 0:
total_test_rewards = self.test(model=model, render=False)
if best_test_rewards is None or sum(total_test_rewards) >= sum(
best_test_rewards
):
self.best_model = model
best_test_rewards = total_test_rewards
self.save_model()
print(
"Training step: {}\nBuffer Size: {}\nLearning rate: {}\nLoss: {}\nLast test score: {}\nBest sest score: {}\n".format(
training_step,
ray.get(replay_buffer.length.remote()),
lr,
loss,
str(total_test_rewards),
str(best_test_rewards),
)
writer.add_scalar(
"2.Workers/Self played games",
ray.get(replay_buffer_worker.get_self_play_count.remote()),
counter,
)
# Finally, save the latest network in the shared storage and end the self-play
storage.set_weights.remote(model.state_dict())
writer.add_scalar(
"2.Workers/Training steps", infos["training_step"], counter
)
writer.add_scalar("3.Loss/1.Total loss", infos["total_loss"], counter)
writer.add_scalar("3.Loss details/Value loss", infos["value_loss"], counter)
writer.add_scalar(
"3.Loss details/Reward loss", infos["reward_loss"], counter
)
writer.add_scalar(
"3.Loss details/Policy loss", infos["policy_loss"], counter
)
counter += 1
time.sleep(1)
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
def test(self, model=None, render=True):
if not model:
model = self.best_model
model.to(self.config.device)
def test(self, render=True):
"""
Test the model in a dedicated thread.
"""
ray.init()
self_play_workers = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
)
test_rewards = []
game = self.Game()
model.eval()
with torch.no_grad():
for _ in range(self.config.test_episodes):
observation = game.reset()
done = False
total_reward = 0
while not done:
if render:
game.render()
root = self_play.MCTS(self.config).run(model, observation, False)
action = self_play.select_action(root, temperature=0)
observation, reward, done = game.step(action)
total_reward += reward
test_rewards.append(total_reward)
history = ray.get(self_play_workers.self_play.remote(0, render))
test_rewards.append(sum(history.rewards))
ray.shutdown()
return test_rewards
def save_model(self, model=None, path=None):
if not model:
model = self.best_model
def load_model(self, path=None, training_step=0):
# TODO: investigate why the pretrained model degrades when training is resumed
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(model.state_dict(), path)
def load_model(self, path=None):
if not path:
path = os.path.join(self.config.results_path, self.game_name)
self.best_model = network.Network(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
try:
self.best_model.load_state_dict(torch.load(path))
self.muzero_weights = torch.load(path)
self.training_steps = training_step
except FileNotFoundError:
print("There is no model saved in {}.".format(path))
if __name__ == "__main__":
# Load the game and the parameters from ./games/file_name.py
muzero = MuZero("cartpole")
muzero.load_model()
muzero.train()
# muzero.load_model()
muzero.test()
import ray
import torch
class Network(torch.nn.Module):
def __init__(self, input_size, action_space_n, encoding_size, hidden_size):
super().__init__()
self.action_space_n = action_space_n
self.representation_network = FullyConnectedNetwork(
input_size, [], encoding_size
)
self.dynamics_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], 1
)
self.prediction_actor_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_n, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, state):
actor_logit = self.prediction_actor_network(state)
value = self.prediction_value_network(state)
return actor_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, state, action):
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_n))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((state, action_one_hot), dim=1)
next_state = self.dynamics_state_network(x)
reward = self.dynamics_reward_network(x)
return next_state, reward
def initial_inference(self, observation):
state = self.representation(observation)
actor_logit, value = self.prediction(state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
actor_logit,
state,
)
def recurrent_inference(self, hidden_state, action):
state, reward = self.dynamics(hidden_state, action)
actor_logit, value = self.prediction(state)
return value, reward, actor_logit, state
def update_weights(optimizer, model, batch, config):
observation_batch, action_batch, target_reward, target_value, target_policy = batch
observation_batch = torch.tensor(observation_batch).float().to(config.device)
action_batch = torch.tensor(action_batch).float().to(config.device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(config.device)
target_reward = torch.tensor(target_reward).float().to(config.device)
target_policy = torch.tensor(target_policy).float().to(config.device)
value, reward, policy_logits, hidden_state = model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(config.num_unroll_steps):
value, reward, policy_logits, hidden_state = model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
loss = 0
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
loss += loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = loss.mean() / config.num_unroll_steps
# Optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss + reward_loss + policy_loss
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
@ray.remote
class SharedStorage:
def __init__(self, model):
self.training_step = 0
self.model = model
def get_weights(self):
return self.model.state_dict()
def set_weights(self, weights):
return self.model.load_state_dict(weights)
def set_training_step(self, training_step):
self.training_step = training_step
def get_training_step(self):
return self.training_step
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Google colab imports\n",
"!pip install -r requirements.txt\n",
"!pip uninstall -y pyarrow\n",
"# If you have an import issue with ray, restart the environment (execution menu)\n",
"\n",
"# You must have the repository imported along with your notebook. \n",
"# For google colab, click on \">\" buttton (left) and import files (muzero.py, self_play.py, ...).\n",
"import muzero as mz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Train on cartpole game\n",
"muzero = mz.MuZero(\"cartpole\")\n",
"muzero.load_model()\n",
"muzero.train()\n",
"muzero.test()"
]
}
]
}
\ No newline at end of file
import numpy
import ray
@ray.remote
class ReplayBuffer:
"""
Class which runs in a dedicated thread to store played games and generate batches.
"""
def __init__(self, config):
self.config = config
self.buffer = []
self.self_play_count = 0
def save_game(self, game_history):
if len(self.buffer) > self.config.window_size:
self.buffer.pop(0)
self.buffer.append(game_history)
self.self_play_count += 1
def get_self_play_count(self):
return self.self_play_count
def get_batch(self):
observation_batch, action_batch, reward_batch, value_batch, policy_batch = (
[],
[],
[],
[],
[],
)
for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer)
game_pos = sample_position(game_history)
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps
]
# Repeat the last action to pad "actions" to the required "num_unroll_steps" length
actions.extend(
[
actions[-1]
for _ in range(self.config.num_unroll_steps - len(actions) + 1)
]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = make_target(
game_history,
game_pos,
self.config.num_unroll_steps,
self.config.td_steps,
)
value_batch.append(value)
reward_batch.append(reward)
policy_batch.append(policy)
return observation_batch, action_batch, value_batch, reward_batch, policy_batch
def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with a probability linked to the difference between real and predicted value (see paper appendix Training)
return numpy.random.choice(buffer)
def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = (
game_history.root_values[bootstrap_index]
* game_history.discount ** td_steps
)
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * game_history.discount ** i
if current_index < len(game_history.root_values):
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game_history.child_visits[0])
for _ in range(len(game_history.child_visits[0]))
]
)
return target_values, target_rewards, target_policies
......@@ -4,43 +4,106 @@ import numpy
import ray
import torch
import models
@ray.remote # (num_gpus=1) # Uncoomment num_gpus and model.to(config.device) to self-play on GPU
def run_selfplay(Game, config, shared_storage, replay_buffer, model, seed):
@ray.remote
class SelfPlay:
"""
Function which run simultaneously in multiple threads and is continuously playing games and saving them to the replay-buffer.
Class which runs in a dedicated thread to play games and save them to the replay buffer.
"""
# model.to(config.device) # Uncoomment this line and num_gpus from the decorator to self-play on GPU
with torch.no_grad():
while True:
# Initialize a self-play
model.load_state_dict(ray.get(shared_storage.get_weights.remote()))
game = Game(seed)
done = False
game_history = GameHistory(config.discount)
# Self-play with actions based on the Monte Carlo tree search at each moves
observation = game.reset()
def __init__(self, initial_weights, game, config, device):
self.config = config
self.game = game
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
with torch.no_grad():
while True:
self.model.set_weights(ray.get(shared_storage.get_weights.remote()))
# Take the best action (no exploration) in test mode
temperature = (
0
if test_mode
else self.config.visit_softmax_temperature_fn(
trained_steps=ray.get(shared_storage.get_infos.remote())[
"training_step"
]
)
)
game_history = self.self_play(temperature, False)
# Save to the shared storage
if test_mode:
shared_storage.set_infos.remote(
"total_reward", sum(game_history.rewards)
)
if not test_mode:
replay_buffer.save_game.remote(game_history)
def self_play(self, temperature, render):
"""
Play one game with actions based on the Monte Carlo tree search at each move.
"""
game_history = GameHistory(self.config.discount)
observation = self.game.reset()
game_history.observation_history.append(observation)
done = False
while not done and len(game_history.history) < self.config.max_moves:
root = MCTS(self.config).run(self.model, observation, True)
action = select_action(root, temperature)
observation, reward, done = self.game.step(action)
if render:
self.game.render()
print("Press enter to step")
game_history.observation_history.append(observation)
while not done and len(game_history.history) < config.max_moves:
root = MCTS(config).run(model, observation, True)
game_history.rewards.append(reward)
game_history.history.append(action)
game_history.store_search_statistics(root, self.config.action_space)
temperature = config.visit_softmax_temperature_fn(
num_moves=len(game_history.history),
trained_steps=ray.get(shared_storage.get_training_step.remote()),
)
action = select_action(root, temperature)
self.game.close()
return game_history
observation, reward, done = game.step(action)
game_history.observation_history.append(observation)
game_history.rewards.append(reward)
game_history.history.append(action)
game_history.store_search_statistics(root, config.action_space)
def select_action(node, temperature):
"""
Select action according to the visit count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
).T
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
else:
# See paper Data Generation appendix
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action_pos = numpy.random.choice(
len(visit_counts[1]), p=visit_count_distribution
)
game.close()
# Save the game history
replay_buffer.save_game.remote(game_history)
if temperature == float("inf"):
action_pos = numpy.random.choice(len(visit_counts[1]))
return visit_counts[1][action_pos]
# Game independent
......@@ -62,7 +125,10 @@ class MCTS:
"""
root = Node(0)
observation = (
torch.from_numpy(observation).to(self.config.device).float().unsqueeze(0)
torch.from_numpy(observation)
.float()
.unsqueeze(0)
.to(next(model.parameters()).device)
)
_, expected_reward, policy_logits, hidden_state = model.initial_inference(
observation
......@@ -76,9 +142,7 @@ class MCTS:
exploration_fraction=self.config.root_exploration_fraction,
)
min_max_stats = MinMaxStats(
self.config.min_known_bound, self.config.max_known_bound
)
min_max_stats = MinMaxStats()
for _ in range(self.config.num_simulations):
node = root
......@@ -141,32 +205,6 @@ class MCTS:
value = node.reward + self.config.discount * value
def select_action(node, temperature, random=False):
"""
Select action according to the vivist count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
).T
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
else:
# See paper Data Generation appendix
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action_pos = numpy.random.choice(
len(visit_counts[1]), p=visit_count_distribution
)
if random:
action_pos = numpy.random.choice(len(visit_counts[1]))
return visit_counts[1][action_pos]
class Node:
def __init__(self, prior):
self.visit_count = 0
......@@ -231,105 +269,14 @@ class GameHistory:
self.root_values.append(root.value())
@ray.remote
class ReplayBuffer:
# Store list of game history and generate batch
def __init__(self, config):
self.window_size = config.window_size
self.batch_size = config.batch_size
self.buffer = []
def save_game(self, game_history):
if len(self.buffer) > self.window_size:
self.buffer.pop(0)
self.buffer.append(game_history)
def sample_batch(self, num_unroll_steps, td_steps):
observation_batch, action_batch, reward_batch, value_batch, policy_batch = (
[],
[],
[],
[],
[],
)
for _ in range(self.batch_size):
game_history = self.sample_game()
game_pos = self.sample_position(game_history)
actions = game_history.history[game_pos : game_pos + num_unroll_steps]
# Repeat precedent action to make "actions" of length "num_unroll_steps"
actions.extend(
[actions[-1] for _ in range(num_unroll_steps - len(actions) + 1)]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = self.make_target(
game_history, game_pos, num_unroll_steps, td_steps
)
reward_batch.append(reward)
value_batch.append(value)
policy_batch.append(policy)
return observation_batch, action_batch, reward_batch, value_batch, policy_batch
def sample_game(self):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability link to the highest difference between real and predicted value (see paper appendix Training)
return self.buffer[numpy.random.choice(range(len(self.buffer)))]
def sample_position(self, game):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
return numpy.random.choice(range(len(game.rewards)))
def make_target(self, game, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game.root_values):
value = game.root_values[bootstrap_index] * game.discount ** td_steps
else:
value = 0
for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
value += reward * game.discount ** i
if current_index < len(game.root_values):
target_values.append(value)
target_rewards.append(game.rewards[current_index])
target_policies.append(game.child_visits[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game.child_visits[0])
for _ in range(len(game.child_visits[0]))
]
)
return target_values, target_rewards, target_policies
def length(self):
return len(self.buffer)
class MinMaxStats:
"""
A class that holds the min-max values of the tree.
"""
def __init__(self, min_value_bound, max_value_bound):
self.maximum = min_value_bound if min_value_bound else -float("inf")
self.minimum = max_value_bound if max_value_bound else float("inf")
def __init__(self):
self.maximum = -float("inf")
self.minimum = float("inf")
def update(self, value):
self.maximum = max(self.maximum, value)
......
import ray
import torch
import os
@ray.remote
class SharedStorage:
"""
Class which runs in a dedicated thread to store the network weights and some information.
"""
def __init__(self, weights, training_step, game_name, config):
self.config = config
self.game_name = game_name
self.weights = weights
self.infos = {'training_step': training_step,
'total_reward': 0,
'total_loss': 0,
'value_loss': 0,
'reward_loss': 0,
'policy_loss': 0}
def get_weights(self):
return self.weights
def set_weights(self, weights, path=None):
self.weights = weights
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(self.weights, path)
def get_infos(self):
return self.infos
def set_infos(self, key, value):
self.infos[key] = value
import time
import numpy
import ray
import torch
import models
@ray.remote(num_gpus=1)
class Trainer:
"""
Class which runs in a dedicated thread to train a neural network and save it in the shared storage.
"""
def __init__(self, initial_weights, initial_training_step, config, device):
self.config = config
self.training_step = initial_training_step
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
self.model.train()
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.config.lr_init,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
)
def continuous_update_weights(self, replay_buffer, shared_storage_worker):
# Wait for the replay buffer to contain at least one game
while ray.get(replay_buffer.get_self_play_count.remote()) < 1:
time.sleep(0.1)
# Training loop
while True:
batch = ray.get(replay_buffer.get_batch.remote())
total_loss, value_loss, reward_loss, policy_loss = self.update_weights(
batch
)
# Save to the shared storage
if self.training_step % self.config.checkpoint_interval == 0:
shared_storage_worker.set_weights.remote(self.model.get_weights())
shared_storage_worker.set_infos.remote("training_step", self.training_step)
shared_storage_worker.set_infos.remote("total_loss", total_loss)
shared_storage_worker.set_infos.remote("value_loss", value_loss)
shared_storage_worker.set_infos.remote("reward_loss", reward_loss)
shared_storage_worker.set_infos.remote("policy_loss", policy_loss)
def update_weights(self, batch):
"""
Perform one training step.
"""
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
(
observation_batch,
action_batch,
target_value,
target_reward,
target_policy,
) = batch
device = next(self.model.parameters()).device
observation_batch = torch.tensor(observation_batch).float().to(device)
action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(device)
target_reward = torch.tensor(target_reward).float().to(device)
target_policy = torch.tensor(target_policy).float().to(device)
value, reward, policy_logits, hidden_state = self.model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(self.config.num_unroll_steps):
value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
# Compute losses
value_loss, reward_loss, policy_loss = (0, 0, 0)
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
(
current_value_loss,
current_reward_loss,
current_policy_loss,
) = loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
value_loss += current_value_loss
reward_loss += current_reward_loss
policy_loss += current_policy_loss
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = (
value_loss + reward_loss + policy_loss
).mean() / self.config.num_unroll_steps
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.training_step += 1
return (
loss.item(),
value_loss.mean().item(),
reward_loss.mean().item(),
policy_loss.mean().item(),
)
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss, reward_loss, policy_loss