Commit ceb98db4 by Werner Duvaud

Add console log

parent 2a7838e4
@@ -7,26 +7,26 @@
# MuZero General
A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable to any game or reinforcement learning environment (like [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
MuZero is a model-based reinforcement learning algorithm, the successor of AlphaZero. It learns to master games without knowing the rules: it only knows the available actions and then learns to play and master the game. It is at least as efficient as similar algorithms such as [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for running the different components simultaneously. There is complete GPU support.
There are four components, implemented as classes, that each run in a dedicated thread.
The `shared storage` holds the latest neural network weights, the `self-play` component uses those weights to generate self-play games and stores them in the `replay buffer`. Finally, those played games are used to `train` the network, and the updated weights are stored back in the shared storage. The circle is complete. See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works).
Those components are launched and managed from the MuZero class in `muzero.py`, and the structure of the neural network is defined in `models.py`.
All performances are tracked and displayed in real time in TensorBoard.
![cartpole training summary](https://github.com/werner-duvaud/muzero-general/blob/master/docs/cartpole_training_summary.png)
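A minimal usage sketch of the workflow described above (the class and method names follow `muzero.py` as it appears later in this commit; treat it as an illustration rather than the canonical entry point):

```python
# Minimal sketch, assuming the repository is on the Python path:
from muzero import MuZero

muzero = MuZero("cartpole")  # loads games/cartpole.py (MuZeroConfig + Game class)
muzero.train()               # launches self-play, replay buffer, trainer and shared storage with Ray
muzero.test()                # plays a few evaluation games with the latest weights
```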
## Games already implemented with pretrained network available
* Lunar Lander
* Cartpole
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/lunarlander_training_preview.png)
## Getting started
### Installation
...
# MuZero General Documentation
Please refer to the [GitHub wiki](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) and to the comments in the code.
\ No newline at end of file
@@ -15,6 +15,7 @@ class MuZeroConfig:
self.max_moves = 500  # Maximum number of moves if game is not finished before
self.num_simulations = 50  # Number of future moves self-simulated
self.discount = 0.997  # Chronological discount of the reward
self.self_play_delay = None  # Number of seconds to wait after each played game, to adjust the self-play / training ratio and avoid overfitting (a 13:1 ratio is recommended, see https://arxiv.org/abs/1902.04522 Appendix A)
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
@@ -24,30 +25,28 @@ class MuZeroConfig:
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained"  # Path to store the model weights
self.training_steps = 2000  # Total number of training steps (i.e. weight updates according to a batch)
self.batch_size = 128  # Number of parts of games to train on at each training step
self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
self.window_size = 1000  # Number of self-play games to keep in the replay buffer
self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
self.training_delay = 1  # Number of seconds to wait after each training step, to adjust the self-play / training ratio and avoid overfitting (a 13:1 ratio is recommended, see https://arxiv.org/abs/1902.04522 Appendix A)
self.weight_decay = 1e-4  # L2 weights regularization
self.momentum = 0.9
# Test
self.test_episodes = 2  # Number of games played to evaluate the network
# Exponential learning rate schedule
self.lr_init = 0.0005  # Initial learning rate
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
@@ -59,7 +58,7 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
...
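For reference, the temperature schedule touched in the hunk above reads roughly as follows; the final branch is not visible in this diff, so its value (0.25) is an assumption:

```python
def visit_softmax_temperature_fn(self, trained_steps):
    """Temperature used by select_action; higher means more exploration."""
    if trained_steps < 0.25 * self.training_steps:
        return 1.0
    elif trained_steps < 0.75 * self.training_steps:
        return 0.5
    else:
        return 0.25  # assumed value, the else branch is outside the hunk
```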
@@ -16,6 +16,7 @@ class MuZeroConfig:
self.max_moves = 500  # Maximum number of moves if game is not finished before
self.num_simulations = 50  # Number of future moves self-simulated
self.discount = 0.997  # Chronological discount of the reward
self.self_play_delay = None  # Number of seconds to wait after each played game, to adjust the self-play / training ratio and avoid overfitting (a 13:1 ratio is recommended, see https://arxiv.org/abs/1902.04522 Appendix A)
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
@@ -25,30 +26,28 @@ class MuZeroConfig:
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 64
self.hidden_size = 32
# Training
self.results_path = "./pretrained"  # Path to store the model weights
self.training_steps = 2000  # Total number of training steps (i.e. weight updates according to a batch)
self.batch_size = 128  # Number of parts of games to train on at each training step
self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
self.window_size = 1000  # Number of self-play games to keep in the replay buffer
self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
self.training_delay = 8  # Number of seconds to wait after each training step, to adjust the self-play / training ratio and avoid overfitting (a 13:1 ratio is recommended, see https://arxiv.org/abs/1902.04522 Appendix A)
self.weight_decay = 1e-4  # L2 weights regularization
self.momentum = 0.9
# Test
self.test_episodes = 2  # Number of games played to evaluate the network
# Exponential learning rate schedule
self.lr_init = 0.0001  # Initial learning rate
self.lr_decay_rate = 0.1
self.lr_decay_steps = 3500
@@ -60,7 +59,7 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.25 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
...
import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
@@ -25,6 +26,7 @@ class FullyConnectedNetwork(torch.nn.Module):
x = layer(x)
return x
# TODO: unified residual network
class MuZeroNetwork(torch.nn.Module):
def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
@@ -89,4 +91,4 @@ class MuZeroNetwork(torch.nn.Module):
return {key: value.cpu() for key, value in self.state_dict().items()}
def set_weights(self, weights):
self.load_state_dict(weights)
\ No newline at end of file
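For readers skimming the diff, here is a small self-contained sketch of what a fully connected module with the constructor shown above looks like; it is illustrative only and not the repository's `FullyConnectedNetwork`:

```python
import torch


class TinyMLP(torch.nn.Module):
    """Toy module with the same constructor shape as FullyConnectedNetwork (not repository code)."""

    def __init__(self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()):
        super().__init__()
        sizes = [input_size] + list(layers_sizes) + [output_size]
        layers = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layers.append(torch.nn.Linear(in_size, out_size))
            layers.append(activation)
        self.layers = torch.nn.ModuleList(layers[:-1])  # no activation after the output layer

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


print(TinyMLP(4, [32, 32], 2)(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
```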
import copy
import importlib
import os
import time
@@ -21,7 +20,8 @@ class MuZero:
Main class to manage MuZero.
Args:
game_name (str): Name of the game module, it should match the name of a .py file
in the "./games" directory.
Example:
>>> muzero = MuZero("cartpole")
@@ -46,18 +46,16 @@ class MuZero:
raise err
# Fix random generator seed for reproducibility
numpy.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
# Initial weights used to initialize components
self.muzero_weights = models.MuZeroNetwork(
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
).get_weights()
def train(self):
ray.init()
@@ -68,16 +66,12 @@ class MuZero:
# Initialize workers
training_worker = trainer.Trainer.remote(
copy.deepcopy(self.muzero_weights),
self.config,
# Train on GPU if available
"cuda" if torch.cuda.is_available() else "cpu",
)
shared_storage_worker = shared_storage.SharedStorage.remote(
copy.deepcopy(self.muzero_weights), self.game_name, self.config,
)
replay_buffer_worker = replay_buffer.ReplayBuffer.remote(self.config)
self_play_workers = [
@@ -107,7 +101,7 @@ class MuZero:
# Loop for monitoring in real time the workers
print(
"\nTraining...\nRun tensorboard --logdir ./ and go to http://localhost:6006/ to see in real time the training performance.\n"
)
counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote())
@@ -126,15 +120,21 @@ class MuZero:
"2.Workers/Training steps", infos["training_step"], counter
)
writer.add_scalar("3.Loss/1.Total loss", infos["total_loss"], counter)
writer.add_scalar("3.Loss/Value loss", infos["value_loss"], counter)
writer.add_scalar("3.Loss/Reward loss", infos["reward_loss"], counter)
writer.add_scalar("3.Loss/Policy loss", infos["policy_loss"], counter)
print(
"Last test reward: {0:.2f}. Training step: {1}/{2}. Played games: {3}. Loss: {4:.2f}".format(
infos["total_reward"],
infos["training_step"],
self.config.training_steps,
ray.get(replay_buffer_worker.get_self_play_count.remote()),
infos["total_loss"],
),
end="\r",
)
counter += 1
time.sleep(3)
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
@@ -142,6 +142,7 @@ class MuZero:
"""
Test the model in a dedicated thread.
"""
print("Testing...")
ray.init()
self_play_workers = self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights), self.Game(), self.config, "cpu",
@@ -149,18 +150,17 @@ class MuZero:
test_rewards = []
with torch.no_grad():
for _ in range(self.config.test_episodes):
history = ray.get(self_play_workers.play_game.remote(0, render))
test_rewards.append(sum(history.rewards))
ray.shutdown()
return test_rewards
def load_model(self, path=None):
if not path:
path = os.path.join(self.config.results_path, self.game_name)
try:
self.muzero_weights = torch.load(path)
print("Using weights from {}".format(path))
except FileNotFoundError:
print("There is no model saved in {}.".format(path))
@@ -168,5 +168,6 @@ class MuZero:
if __name__ == "__main__":
muzero = MuZero("cartpole")
muzero.train()
muzero.load_model()
muzero.test()
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Untitled3.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Google Colab setup\n",
        "!pip install -r requirements.txt\n",
        "!pip uninstall -y pyarrow\n",
        "%load_ext tensorboard\n",
        "# If you have an import issue with ray in Google Colab, restart the environment (execution menu)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# You must have the repository imported along with your notebook.\n",
        "# For Google Colab, click on the \">\" button (left) and import the files (muzero.py, self_play.py, ...).\n",
        "\n",
        "import muzero as mz"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Train on the cartpole game\n",
        "muzero = mz.MuZero(\"cartpole\")\n",
        "muzero.train()\n",
        "muzero.test()"
      ]
    }
  ]
}
\ No newline at end of file
No preview for this file type
No preview for this file type
@@ -62,7 +62,8 @@ def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability linked to the highest difference between real and
# predicted value (see paper appendix Training)
return numpy.random.choice(buffer)
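The TODO above suggests prioritized sampling; a hedged sketch of what that could look like (the priority definition and the function name are assumptions, not repository code):

```python
import numpy


def sample_game_prioritized(buffer, priorities):
    """Pick a game with probability proportional to its priority,
    e.g. the gap between predicted and target value."""
    probabilities = numpy.asarray(priorities, dtype=float)
    probabilities = probabilities / probabilities.sum()
    index = numpy.random.choice(len(buffer), p=probabilities)
    return buffer[index]
```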
@@ -76,7 +77,8 @@ def sample_position(game_history):
def make_target(game_history, state_index, num_unroll_steps, td_steps):
"""
The value target is the discounted root value of the search tree td_steps into the
future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
...
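A worked sketch of the value target described in the docstring above (the helper name and signature are illustrative; see `make_target` for the actual implementation):

```python
def n_step_value_target(root_values, rewards, index, td_steps, discount):
    """Discounted root value td_steps ahead plus the discounted rewards until then."""
    bootstrap_index = index + td_steps
    if bootstrap_index < len(root_values):
        value = root_values[bootstrap_index] * discount ** td_steps
    else:
        value = 0.0
    for i, reward in enumerate(rewards[index:bootstrap_index]):
        value += reward * discount ** i
    return value


# Example: with discount=0.997 and td_steps=2, the target at index 0 is
# rewards[0] + 0.997 * rewards[1] + 0.997**2 * root_values[2]
print(n_step_value_target([0.0, 0.0, 10.0], [1.0, 1.0, 1.0], 0, 2, 0.997))
```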
import math
import time
import copy
import numpy
import ray
import torch
@@ -12,6 +13,7 @@ class SelfPlay:
"""
Class which runs in a dedicated thread to play games and save them to the replay buffer.
"""
def __init__(self, initial_weights, game, config, device):
self.config = config
self.game = game
@@ -25,11 +27,14 @@ class SelfPlay:
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(device))
self.model.eval()
def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
with torch.no_grad():
while True:
self.model.set_weights(
copy.deepcopy(ray.get(shared_storage.get_weights.remote()))
)
# Take the best action (no exploration) in test mode
temperature = (
@@ -41,8 +46,7 @@ class SelfPlay:
]
)
)
game_history = self.play_game(temperature, False)
# Save to the shared storage
if test_mode:
@@ -52,7 +56,10 @@ class SelfPlay:
if not test_mode:
replay_buffer.save_game.remote(game_history)
if not test_mode and self.config.self_play_delay:
time.sleep(self.config.self_play_delay)
def play_game(self, temperature, render):
"""
Play one game with actions based on the Monte Carlo tree search at each move.
"""
@@ -83,7 +90,8 @@ class SelfPlay:
def select_action(node, temperature):
"""
Select action according to the visit count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function
in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
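The rest of `select_action` falls outside this hunk; below is a sketch of the usual temperature-based choice from visit counts (an assumption about the elided code, not a copy of it):

```python
import numpy


def pick_action(visit_counts, actions, temperature):
    """Sample an action from visit counts softened by the temperature (0 = greedy)."""
    visit_counts = numpy.asarray(visit_counts, dtype=float)
    if temperature == 0:
        return actions[int(numpy.argmax(visit_counts))]
    distribution = visit_counts ** (1 / temperature)
    distribution = distribution / distribution.sum()
    return int(numpy.random.choice(actions, p=distribution))
```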
@@ -120,8 +128,10 @@ class MCTS:
def run(self, model, observation, add_exploration_noise):
"""
At the root of the search tree we use the representation function to obtain a
hidden state given the current observation.
We then run a Monte Carlo Tree Search using only action sequences and the model
learned by the network.
"""
root = Node(0)
observation = (
@@ -153,7 +163,8 @@ class MCTS:
last_action = action
search_path.append(node)
# Inside the search tree we use the dynamics function to obtain the next hidden
# state given an action and the previous hidden state
parent = search_path[-2]
value, reward, policy_logits, hidden_state = model.recurrent_inference(
parent.hidden_state,
@@ -194,10 +205,12 @@ class MCTS:
def backpropagate(self, search_path, value, min_max_stats):
"""
At the end of a simulation, we propagate the evaluation all the way up the tree
to the root.
"""
for node in search_path:
# Always the same player: the other players' minds should be modeled by the network,
# because the environment does not always act in the way that makes you lose
node.value_sum += value  # if node.to_play == to_play else -value
node.visit_count += 1
min_max_stats.update(node.value())
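The `min_max_stats` object used above normalizes Q-values to the range observed in the tree; here is a sketch of such a helper (a common MuZero utility, possibly differing in detail from the repository's version):

```python
class MinMaxStats:
    """Tracks the minimum and maximum value seen and normalizes values into [0, 1]."""

    def __init__(self):
        self.minimum = float("inf")
        self.maximum = -float("inf")

    def update(self, value):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
```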
@@ -225,7 +238,8 @@ class Node:
def expand(self, actions, reward, policy_logits, hidden_state):
"""
We expand a node using the value, reward and policy prediction obtained from the
neural network.
"""
self.reward = reward
self.hidden_state = hidden_state
@@ -236,7 +250,8 @@ class Node:
def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
"""
At the start of each search, we add Dirichlet noise to the prior of the root to
encourage the search to explore new actions.
"""
actions = list(self.children.keys())
noise = numpy.random.dirichlet([dirichlet_alpha] * len(actions))
...
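The noise mixing performed just after the lines above follows the MuZero/AlphaZero pseudocode; a standalone sketch with illustrative names:

```python
import numpy


def mix_root_priors(priors, dirichlet_alpha, exploration_fraction):
    """Blend the root prior with Dirichlet noise to encourage exploration."""
    noise = numpy.random.dirichlet([dirichlet_alpha] * len(priors))
    return [
        prior * (1 - exploration_fraction) + n * exploration_fraction
        for prior, n in zip(priors, noise)
    ]


print(mix_root_priors([0.7, 0.2, 0.1], dirichlet_alpha=0.25, exploration_fraction=0.25))
```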
@@ -2,21 +2,25 @@ import ray
import torch
import os
@ray.remote
class SharedStorage:
"""
Class which runs in a dedicated thread to store the network weights and some information.
"""
def __init__(self, weights, game_name, config):
self.config = config
self.game_name = game_name
self.weights = weights
self.infos = {
"training_step": 0,
"total_reward": 0,
"total_loss": 0,
"value_loss": 0,
"reward_loss": 0,
"policy_loss": 0,
}
def get_weights(self):
return self.weights
@@ -25,6 +29,7 @@ class SharedStorage:
self.weights = weights
if not path:
path = os.path.join(self.config.results_path, self.game_name)
torch.save(self.weights, path)
def get_infos(self):
...
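The shared-storage pattern above boils down to a Ray actor holding the latest weights and training infos; here is a self-contained toy (not the repository class) showing how the other workers read and write it:

```python
import ray


@ray.remote
class ToyStorage:
    """Minimal stand-in for SharedStorage: holds weights and a few counters."""

    def __init__(self, weights):
        self.weights = weights
        self.infos = {"training_step": 0, "total_reward": 0}

    def get_weights(self):
        return self.weights

    def set_infos(self, key, value):
        self.infos[key] = value

    def get_infos(self):
        return self.infos


ray.init()
storage = ToyStorage.remote({"layer.weight": [0.0]})
storage.set_infos.remote("training_step", 1)
print(ray.get(storage.get_infos.remote()))  # {'training_step': 1, 'total_reward': 0}
ray.shutdown()
```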
@@ -10,11 +10,13 @@ import models
@ray.remote(num_gpus=1)
class Trainer:
"""
Class which runs in a dedicated thread to train a neural network and save it
in the shared storage.
"""
def __init__(self, initial_weights, config, device):
self.config = config
self.training_step = 0
# Initialize the network
self.model = models.MuZeroNetwork(
@@ -38,7 +40,7 @@ class Trainer:
# Wait for the replay buffer to be filled
while ray.get(replay_buffer.get_self_play_count.remote()) < 1:
time.sleep(0.1)
# Training loop
while True:
batch = ray.get(replay_buffer.get_batch.remote())
@@ -55,6 +57,9 @@ class Trainer:
shared_storage_worker.set_infos.remote("reward_loss", reward_loss)
shared_storage_worker.set_infos.remote("policy_loss", policy_loss)
if self.config.training_delay:
time.sleep(self.config.training_delay)
def update_weights(self, batch):
"""
Perform one training step.
@@ -66,7 +71,6 @@ class Trainer:
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
(
observation_batch,
action_batch,
...
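The learning-rate assignment visible at the end of the hunk above pairs with the exponential schedule configured by `lr_init`, `lr_decay_rate` and `lr_decay_steps`; a hedged sketch of how the three are typically combined (the exact formula lives in the elided part of `update_weights`):

```python
def exponential_lr(lr_init, lr_decay_rate, lr_decay_steps, training_step):
    """Exponentially decayed learning rate, as commonly used in MuZero trainers."""
    return lr_init * lr_decay_rate ** (training_step / lr_decay_steps)


# Example with the cartpole values from this commit (assumed combination):
lr = exponential_lr(0.0005, 0.1, 3500, training_step=1000)
# for param_group in optimizer.param_groups:
#     param_group["lr"] = lr
print(round(lr, 6))  # about 0.000259
```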