Commit 06564837 by Werner Duvaud

Refactored

parent d8388353
......@@ -7,17 +7,27 @@
# MuZero General
A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable for every games or reinforcement learning environnements (like [gym](https://github.com/openai/gym)). You only need to edit the game file with the parameters and the game class. Please refer to the documentation and the tutorial.
It is designed to be easily adaptable for every games or reinforcement learning environments (like [gym](https://github.com/openai/gym)). You only need to edit the game file with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games whithout knowing the rules. It only know actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games without knowing the rules. It only knows actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for running the different components simultaneously. There is a complete GPU support.
There are four "actors" which are classes that run simultaneously in a dedicated thread.
The shared storage holds the latest neural network weights, the self-play uses those weights to generate self-play games and store them in the replay buffer. Finally, those games are used to train a network and store the weights in the shared storage. The circle is complete.
Those components are launched and managed from the MuZero class in muzero.py and the structure of the neural network is defined in models.py.
All performances are tracked and displayed in real time in tensorboard.
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained/pretrained/cartpole_training_summary.png)
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for self-playing on multiple threads. A synchronous mode (easier for debug) will be released. There is a complete GPU support.
The code has three parts, muzero.py with the entry class, self-play.py with the replay-buffer and the MCTS classes, and network.py with the neural networks and the shared storage classes.
## Games already implemented with pretrained network available
* Lunar Lander
* Cartpole
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/tree/master/games/lunarlander_training_preview.png)
## Getting started
### Installation
```bash
......@@ -26,32 +36,35 @@ pip install -r requirements.txt
```
### Training
Edit the end of muzero.py :
Edit the end of muzero.py:
```python
muzero = Muzero("cartpole")
muzero.train()
```
Then run :
Then run:
```bash
python muzero.py
```
To visualize the training results, run in a new bash:
```bash
tensorboard --logdir ./
```
### Testing
Edit the end of muzero.py :
Edit the end of muzero.py:
```python
muzero = Muzero("cartpole")
muzero.load_model()
muzero.test()
```
Then run :
Then run:
```bash
python muzero.py
```
## Coming soon
* [ ] Convolutionnal / Atari mode
* [ ] Performance tracking
* [ ] Synchronous mode
* [ ] Atari mode with residual network
* [ ] Live test policy & value tracking
* [ ] [Open spiel](https://github.com/deepmind/open_spiel) integration
* [ ] Checkers game
* [ ] TensorFlow mode
......
import gym
import numpy
import torch
class MuZeroConfig:
......@@ -31,20 +30,18 @@ class MuZeroConfig:
self.max_known_bound = None
### Network
self.encoding_size = 32
self.hidden_size = 64
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available
self.training_steps = 400 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 1000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.test_episodes = 2 # Number of game played to evaluate the network
self.checkpoint_interval = 10 # Number of training steps before using the model for sef-playing
self.window_size = 1000 # Number of self-play games to keep in memory (in the replay buffer)
self.td_steps = 10 # Number of steps in the futur to take into account for calculating the target value
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of steps in the futur to take into account for calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
......@@ -54,7 +51,7 @@ class MuZeroConfig:
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps):
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.
......
......@@ -13,7 +13,7 @@ class MuZeroConfig:
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 100 # Maximum number of moves if game is not finished before
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of futur moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
......@@ -31,30 +31,28 @@ class MuZeroConfig:
self.max_known_bound = None
### Network
self.encoding_size = 16
self.hidden_size = 8
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained" # Path to store the model weights
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Automatically use GPU instead of CPU if available
self.training_steps = 700 # Total number of training steps (ie weights update according to a batch)
self.training_steps = 20000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 50 # Number of game moves to keep for every batch element
self.test_interval = 20 # Number of training steps before evaluating the network on the game to track the performance
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.test_episodes = 2 # Number of game played to evaluate the network
self.checkpoint_interval = 20 # Number of training steps before using the model for sef-playing
self.window_size = 1000 # Number of self-play games to keep in memory (in the replay buffer)
self.td_steps = 50 # Number of steps in the futur to take into account for calculating the target value
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 100 # Number of steps in the futur to take into account for calculating the target value
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.005 # Initial learning rate
self.lr_decay_rate = 0.01
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
def visit_softmax_temperature_fn(self, num_moves, trained_steps):
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.
......@@ -62,14 +60,12 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
return 1000
elif trained_steps < 0.5 * self.training_steps:
return 1
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
return 0.1
return 0.25
class Game:
......
import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# TODO: unified residual network
class MuZeroNetwork(torch.nn.Module):
def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
super().__init__()
self.action_space_size = action_space_size
self.representation_network = FullyConnectedNetwork(
observation_size, [], encoding_size
)
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], 1
)
self.prediction_policy_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_size, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, encoded_state):
policy_logit = self.prediction_policy_network(encoded_state)
value = self.prediction_value_network(encoded_state)
return policy_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, encoded_state, action):
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logit, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
policy_logit,
encoded_state,
)
def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
def get_weights(self):
return {key: value.cpu() for key, value in self.state_dict().items()}
def set_weights(self, weights):
self.load_state_dict(weights)
\ No newline at end of file
import copy
import datetime
import importlib
import os
import time
......@@ -5,9 +7,13 @@ import time
import numpy
import ray
import torch
from torch.utils.tensorboard import SummaryWriter
import network
import models
import replay_buffer
import self_play
import shared_storage
import trainer
class MuZero:
......@@ -44,144 +50,123 @@ class MuZero:
numpy.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
self.best_model = network.Network(
# Used to initialize components when continuing a former training
self.muzero_weights = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
).get_weights()
self.training_steps = 0
def train(self):
# Initialize and launch components that work simultaneously
ray.init()
model = self.best_model
model.train()
storage = network.SharedStorage.remote(model)
replay_buffer = self_play.ReplayBuffer.remote(self.config)
for seed in range(self.config.num_actors):
self_play.run_selfplay.remote(
self.Game,
writer = SummaryWriter(
os.path.join(self.config.results_path, self.game_name + "_summary")
)
# Initialize workers
training_worker = trainer.Trainer.remote(
copy.deepcopy(self.muzero_weights),
self.training_steps,
self.config,
# Train on GPU if available
"cuda" if torch.cuda.is_available() else "cpu",
)
shared_storage_worker = shared_storage.SharedStorage.remote(
copy.deepcopy(self.muzero_weights),
self.training_steps,
self.game_name,
self.config,
)
replay_buffer_worker = replay_buffer.ReplayBuffer.remote(self.config)
self_play_workers = [
self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights),
self.Game(self.config.seed + seed),
self.config,
storage,
replay_buffer,
model,
self.config.seed + seed,
"cpu",
)
# Initialize network for training
model = model.to(self.config.device)
optimizer = torch.optim.SGD(
model.parameters(),
lr=self.config.lr_init,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
for seed in range(self.config.num_actors)
]
test_worker = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
)
# Wait for replay buffer to be non-empty
while ray.get(replay_buffer.length.remote()) == 0:
time.sleep(0.1)
# Training loop
best_test_rewards = None
for training_step in range(self.config.training_steps):
model.train()
storage.set_training_step.remote(training_step)
# Make the model available for self-play
if training_step % self.config.checkpoint_interval == 0:
storage.set_weights.remote(model.state_dict())
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
training_step / self.config.lr_decay_steps
# Launch workers
[
self_play_worker.continuous_self_play.remote(
shared_storage_worker, replay_buffer_worker
)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
for self_play_worker in self_play_workers
]
test_worker.continuous_self_play.remote(shared_storage_worker, None, True)
training_worker.continuous_update_weights.remote(
replay_buffer_worker, shared_storage_worker
)
# Train on a batch.
batch = ray.get(
replay_buffer.sample_batch.remote(
self.config.num_unroll_steps, self.config.td_steps
)
# Loop for monitoring in real time the workers
print(
"Run tensorboard --logdir ./ and go to http://localhost:6006/ to track the training performance"
)
counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote())
while infos["training_step"] < self.config.training_steps:
# Get and save real time performance
infos = ray.get(shared_storage_worker.get_infos.remote())
writer.add_scalar(
"1.Total reward/Total reward", infos["total_reward"], counter
)
loss = network.update_weights(optimizer, model, batch, self.config)
# Test the current model and save it on disk if it is the best
if training_step % self.config.test_interval == 0:
total_test_rewards = self.test(model=model, render=False)
if best_test_rewards is None or sum(total_test_rewards) >= sum(
best_test_rewards
):
self.best_model = model
best_test_rewards = total_test_rewards
self.save_model()
print(
"Training step: {}\nBuffer Size: {}\nLearning rate: {}\nLoss: {}\nLast test score: {}\nBest sest score: {}\n".format(
training_step,
ray.get(replay_buffer.length.remote()),
lr,
loss,
str(total_test_rewards),
str(best_test_rewards),
)
writer.add_scalar(
"2.Workers/Self played games",
ray.get(replay_buffer_worker.get_self_play_count.remote()),
counter,
)
# Finally, save the latest network in the shared storage and end the self-play
storage.set_weights.remote(model.state_dict())
writer.add_scalar(
"2.Workers/Training steps", infos["training_step"], counter
)
writer.add_scalar("3.Loss/1.Total loss", infos["total_loss"], counter)
writer.add_scalar("3.Loss details/Value loss", infos["value_loss"], counter)
writer.add_scalar(
"3.Loss details/Reward loss", infos["reward_loss"], counter
)
writer.add_scalar(
"3.Loss details/Policy loss", infos["policy_loss"], counter
)
counter += 1
time.sleep(1)
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
def test(self, model=None, render=True):
if not model:
model = self.best_model
model.to(self.config.device)
def test(self, render=True):
"""
Test the model in a dedicated thread.
"""
ray.init()
self_play_workers = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
)
test_rewards = []
game = self.Game()
model.eval()
with torch.no_grad():
for _ in range(self.config.test_episodes):
observation = game.reset()
done = False
total_reward = 0
while not done:
if render:
game.render()
root = self_play.MCTS(self.config).run(model, observation, False)
action = self_play.select_action(root, temperature=0)
observation, reward, done = game.step(action)
total_reward += reward
test_rewards.append(total_reward)
history = ray.get(self_play_workers.self_play.remote(0, render))
test_rewards.append(sum(history.rewards))
ray.shutdown()
return test_rewards
def save_model(self, model=None, path=None):
if not model:
model = self.best_model
def load_model(self, path=None, training_step=0):
# TODO: why pretrained model is degradated during the new train
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(model.state_dict(), path)
def load_model(self, path=None):
if not path:
path = os.path.join(self.config.results_path, self.game_name)
self.best_model = network.Network(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
try:
self.best_model.load_state_dict(torch.load(path))
self.muzero_weights = torch.load(path)
self.training_step = training_step
except FileNotFoundError:
print("There is no model saved in {}.".format(path))
if __name__ == "__main__":
# Load the game and the parameters from ./games/file_name.py
muzero = MuZero("cartpole")
muzero.load_model()
muzero.train()
# muzero.load_model()
muzero.test()
import ray
import torch
class Network(torch.nn.Module):
def __init__(self, input_size, action_space_n, encoding_size, hidden_size):
super().__init__()
self.action_space_n = action_space_n
self.representation_network = FullyConnectedNetwork(
input_size, [], encoding_size
)
self.dynamics_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], encoding_size
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_n, [hidden_size], 1
)
self.prediction_actor_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_n, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
)
def prediction(self, state):
actor_logit = self.prediction_actor_network(state)
value = self.prediction_value_network(state)
return actor_logit, value
def representation(self, observation):
return self.representation_network(observation)
def dynamics(self, state, action):
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_n))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
x = torch.cat((state, action_one_hot), dim=1)
next_state = self.dynamics_state_network(x)
reward = self.dynamics_reward_network(x)
return next_state, reward
def initial_inference(self, observation):
state = self.representation(observation)
actor_logit, value = self.prediction(state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
actor_logit,
state,
)
def recurrent_inference(self, hidden_state, action):
state, reward = self.dynamics(hidden_state, action)
actor_logit, value = self.prediction(state)
return value, reward, actor_logit, state
def update_weights(optimizer, model, batch, config):
observation_batch, action_batch, target_reward, target_value, target_policy = batch
observation_batch = torch.tensor(observation_batch).float().to(config.device)
action_batch = torch.tensor(action_batch).float().to(config.device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(config.device)
target_reward = torch.tensor(target_reward).float().to(config.device)
target_policy = torch.tensor(target_policy).float().to(config.device)
value, reward, policy_logits, hidden_state = model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(config.num_unroll_steps):
value, reward, policy_logits, hidden_state = model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
loss = 0
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
loss += loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = loss.mean() / config.num_unroll_steps
# Optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss + reward_loss + policy_loss
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
@ray.remote
class SharedStorage:
def __init__(self, model):
self.training_step = 0
self.model = model
def get_weights(self):
return self.model.state_dict()
def set_weights(self, weights):
return self.model.load_state_dict(weights)
def set_training_step(self, training_step):
self.training_step = training_step
def get_training_step(self):
return self.training_step
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Google colab imports\n",
"!pip install -r requirements.txt\n",
"!pip uninstall -y pyarrow\n",
"# If you have an import issue with ray, restart the environment (execution menu)\n",
"\n",
"# You must have the repository imported along with your notebook. \n",
"# For google colab, click on \">\" buttton (left) and import files (muzero.py, self_play.py, ...).\n",
"import muzero as mz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Train on cartpole game\n",
"muzero = mz.MuZero(\"cartpole\")\n",
"muzero.load_model()\n",
"muzero.train()\n",
"muzero.test()"
]
}
]
}
\ No newline at end of file
No preview for this file type
No preview for this file type
import numpy
import ray
@ray.remote
class ReplayBuffer:
"""
Class which run in a dedicated thread to store played games and generate batch.
"""
def __init__(self, config):
self.config = config
self.buffer = []
self.self_play_count = 0
def save_game(self, game_history):
if len(self.buffer) > self.config.window_size:
self.buffer.pop(0)
self.buffer.append(game_history)
self.self_play_count += 1
def get_self_play_count(self):
return self.self_play_count
def get_batch(self):
observation_batch, action_batch, reward_batch, value_batch, policy_batch = (
[],
[],
[],
[],
[],
)
for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer)
game_pos = sample_position(game_history)
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps
]
# Repeat precedent action to make "actions" of length "num_unroll_steps"
actions.extend(
[
actions[-1]
for _ in range(self.config.num_unroll_steps - len(actions) + 1)
]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = make_target(
game_history,
game_pos,
self.config.num_unroll_steps,
self.config.td_steps,
)
value_batch.append(value)
reward_batch.append(reward)
policy_batch.append(policy)
return observation_batch, action_batch, value_batch, reward_batch, policy_batch
def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability link to the highest difference between real and predicted value (see paper appendix Training)
return numpy.random.choice(buffer)
def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = (
game_history.root_values[bootstrap_index]
* game_history.discount ** td_steps
)
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * game_history.discount ** i
if current_index < len(game_history.root_values):
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game_history.child_visits[0])
for _ in range(len(game_history.child_visits[0]))
]
)
return target_values, target_rewards, target_policies
import ray
import torch
import os
@ray.remote
class SharedStorage:
"""
Class which run in a dedicated thread to store the network weights and some information.
"""
def __init__(self, weights, training_step, game_name, config):
self.config = config
self.game_name = game_name
self.weights = weights
self.infos = {'training_step': training_step,
'total_reward': 0,
'total_loss': 0,
'value_loss': 0,
'reward_loss': 0,
'policy_loss': 0}
def get_weights(self):
return self.weights
def set_weights(self, weights, path=None):
self.weights = weights
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(self.weights, path)
def get_infos(self):
return self.infos
def set_infos(self, key, value):
self.infos[key] = value
import time
import numpy
import ray
import torch
import models
@ray.remote(num_gpus=1)
class Trainer:
"""
Class which run in a dedicated thread to train a neural network and save it in the shared storage.
"""
def __init__(self, initial_weights, initial_training_step, config, device):
self.config = config
self.training_step = initial_training_step
# Initialize the network
self.model = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
self.model.train()
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.config.lr_init,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
)
def continuous_update_weights(self, replay_buffer, shared_storage_worker):
# Wait for the replay buffer to be filled
while ray.get(replay_buffer.get_self_play_count.remote()) < 1:
time.sleep(0.1)
# Training loop
while True:
batch = ray.get(replay_buffer.get_batch.remote())
total_loss, value_loss, reward_loss, policy_loss = self.update_weights(
batch
)
# Save to the shared storage
if self.training_step % self.config.checkpoint_interval == 0:
shared_storage_worker.set_weights.remote(self.model.get_weights())
shared_storage_worker.set_infos.remote("training_step", self.training_step)
shared_storage_worker.set_infos.remote("total_loss", total_loss)
shared_storage_worker.set_infos.remote("value_loss", value_loss)
shared_storage_worker.set_infos.remote("reward_loss", reward_loss)
shared_storage_worker.set_infos.remote("policy_loss", policy_loss)
def update_weights(self, batch):
"""
Perform one training step.
"""
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
(
observation_batch,
action_batch,
target_value,
target_reward,
target_policy,
) = batch
device = next(self.model.parameters()).device
observation_batch = torch.tensor(observation_batch).float().to(device)
action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
target_value = torch.tensor(target_value).float().to(device)
target_reward = torch.tensor(target_reward).float().to(device)
target_policy = torch.tensor(target_policy).float().to(device)
value, reward, policy_logits, hidden_state = self.model.initial_inference(
observation_batch
)
predictions = [(value, reward, policy_logits)]
for action_i in range(self.config.num_unroll_steps):
value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
hidden_state, action_batch[:, action_i]
)
predictions.append((value, reward, policy_logits))
# Compute losses
value_loss, reward_loss, policy_loss = (0, 0, 0)
for i, prediction in enumerate(predictions):
value, reward, policy_logits = prediction
(
current_value_loss,
current_reward_loss,
current_policy_loss,
) = loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
)
value_loss += current_value_loss
reward_loss += current_reward_loss
policy_loss += current_policy_loss
# Scale gradient by number of unroll steps (See paper Training appendix)
loss = (
value_loss + reward_loss + policy_loss
).mean() / self.config.num_unroll_steps
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.training_step += 1
return (
loss.item(),
value_loss.mean().item(),
reward_loss.mean().item(),
policy_loss.mean().item(),
)
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss(reduction="none")(value, target_value)
reward_loss = torch.nn.MSELoss(reduction="none")(reward, target_reward)
policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy).sum(1)
return value_loss, reward_loss, policy_loss
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment