Commit 4b81586b by Werner Duvaud

Add multiplayer mode

parent c9059c39
......@@ -6,31 +6,52 @@
# MuZero General
A flexible, commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable to any game or reinforcement learning environment (such as [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the documentation and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
It is designed to be easily adaptable to any game or reinforcement learning environment (such as [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the [documentation](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
MuZero is a model-based reinforcement learning algorithm and the successor of AlphaZero. It learns to master games without knowing the rules: it is only given the possible actions and then learns to play and master the game. It is more efficient than similar algorithms such as [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122).
MuZero is a model-based reinforcement learning algorithm and the successor of AlphaZero. It learns to master games without knowing the rules: it is only given the possible actions and then learns to play and master the game. It is more efficient than similar algorithms such as [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122). See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works)
It uses [PyTorch](https://github.com/pytorch/pytorch) and [Ray](https://github.com/ray-project/ray) for running the different components simultaneously. GPU training is supported. See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works)
## Features
All performance metrics are tracked and displayed in real time in TensorBoard.
![cartpole training summary](https://github.com/werner-duvaud/muzero-general/blob/master/docs/cartpole_training_summary.png)
* [x] Fully connected network in [PyTorch](https://github.com/pytorch/pytorch)
* [x] Multi-Threaded with [Ray](https://github.com/ray-project/ray)
* [x] CPU/GPU support
* [x] TensorBoard real-time monitoring
* [x] Single and multiplayer mode
* [x] Commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation)
* [x] Easily adaptable for new games
* [x] [Examples](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py) of board and Gym games (See [list below](https://github.com/werner-duvaud/muzero-general#games-already-implemented-with-pretrained-network-available))
* [x] [Pretrained weights](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained) available
* [ ] Play against MuZero mode with policy and value tracking
* [ ] Residual Network
* [ ] Atari games
## Games already implemented with pretrained network available
* Lunar Lander
* Cartpole
* Lunar Lander
* Connect4
## Demo
All performance metrics are tracked and displayed in real time in TensorBoard:
![cartpole training summary](https://github.com/werner-duvaud/muzero-general/blob/master/docs/cartpole_training_summary.png)
Testing Lunar Lander:
![lunarlander training preview](https://github.com/werner-duvaud/muzero-general/blob/master/docs/lunarlander_training_preview.png)
## Getting started
### Installation
```bash
git clone https://github.com/werner-duvaud/muzero-general.git
cd muzero-general
pip install -r requirements.txt
```
### Training
Edit the end of muzero.py:
```python
muzero = MuZero("cartpole")
......@@ -40,12 +61,13 @@ Then run:
```bash
python muzero.py
```
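The snippet above is cut short by the diff. Judging from the `__main__` block of muzero.py shown later in this commit, a minimal training setup plausibly looks like this (the game name is the only part you normally edit):

```python
if __name__ == "__main__":
    muzero = MuZero("cartpole")  # any game file from games/, e.g. "connect4"
    muzero.train()
```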
To visualize the training results, run in a new bash:
To visualize the training results, run in a new terminal:
```bash
tensorboard --logdir ./
```
### Testing
Edit the end of muzero.py:
```python
muzero = MuZero("cartpole")
......@@ -57,14 +79,8 @@ Then run:
python muzero.py
```
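As above, the snippet is truncated. Based on the same `__main__` block (which enables `muzero.load_model()` before `muzero.test()`), the end of muzero.py for testing plausibly looks like:

```python
if __name__ == "__main__":
    muzero = MuZero("cartpole")
    muzero.load_model()  # presumably loads the weights saved under config.results_path
    muzero.test()
```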
## Coming soon
* [ ] Atari mode with residual network
* [ ] Live test policy & value tracking
* [ ] [Open spiel](https://github.com/deepmind/open_spiel) integration
* [ ] Checkers game
* [ ] TensorFlow mode
## Authors
* Werner Duvaud
* Aurèle Hainaut
* Paul Lenoir
......@@ -11,6 +11,7 @@ class MuZeroConfig:
### Game
self.observation_shape = 4 # Dimensions of the game observation
self.action_space = [i for i in range(2)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
### Self-Play
......@@ -95,7 +96,16 @@ class Game:
            The new observation, the reward and a boolean indicating whether the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
return observation, reward, done
def to_play(self):
"""
Return the current player.
Returns:
            The current player; it should be an element of the players list in the config.
"""
return 0
def reset(self):
"""
......
import gym
import numpy
import torch
class MuZeroConfig:
def __init__(self):
self.seed = 0 # Seed for numpy, torch and the game
### Game
self.observation_shape = 6 * 7 # Dimensions of the game observation
self.action_space = [i for i in range(7)] # Fixed list of all possible actions
self.players = [i for i in range(2)] # List of players
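        # Connect4 is a two-player game: MCTS alternates between player ids 0 and 1 (see virtual_to_play below)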
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 50 # Maximum number of moves if game is not finished before
        self.num_simulations = 30 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
self.root_exploration_fraction = 0.25
# UCB formula
self.pb_c_base = 19652
self.pb_c_init = 1.25
### Network
self.encoding_size = 32
self.hidden_size = 64
### Training
self.results_path = "./pretrained" # Path to store the model weights
        self.training_steps = 10000 # Total number of training steps (i.e. weight updates according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
        self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
        self.td_steps = 10 # Number of steps in the future to take into account for calculating the target value
self.training_delay = 0 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
self.training_device = "cuda" if torch.cuda.is_available() else "cpu" # Train on GPU if available
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.05 # Initial learning rate
self.lr_decay_rate = 0.01
self.lr_decay_steps = 10000
### Test
        self.test_episodes = 2 # Number of games played to evaluate the network
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
        The smaller it is, the more likely the best action (i.e. with the highest visit count) is chosen.
Returns:
Positive float.
"""
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
return 0.25
class Game:
"""
Game wrapper.
"""
def __init__(self, seed=None):
self.env = Connect4()
def step(self, action):
"""
Apply action to the game.
Args:
action : action of the action_space to take.
Returns:
            The new observation, the reward and a boolean indicating whether the game has ended.
"""
if action not in self.env.legal_actions():
observation, reward, done = self.env.step(self.env.legal_actions()[0])
reward = -1
done = True
else:
observation, reward, done = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
def to_play(self):
"""
Return the current player.
Returns:
            The current player; it should be an element of the players list in the config.
"""
return self.env.to_play()
def reset(self):
"""
Reset the game for a new game.
Returns:
Initial observation of the game.
"""
return numpy.array(self.env.reset()).flatten()
def close(self):
"""
Properly close the game.
"""
pass
def render(self):
"""
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")
class Connect4:
def __init__(self):
self.board = numpy.zeros((6, 7)).astype(int)
self.player = 1
def to_play(self):
return 0 if self.player == 1 else 1
def reset(self):
self.board = numpy.zeros((6, 7)).astype(int)
self.player = 1
return self.get_observation()
def step(self, action):
for i in range(6):
if self.board[i][action] == 0:
self.board[i][action] = self.player
break
done = self.is_finished()
self.player *= -1
return self.get_observation(), 1 if done else 0, done
def get_observation(self):
if self.player == 1:
return self.board
else:
return -self.board
def legal_actions(self):
legal = []
for i in range(7):
for j in range(6):
if self.board[j][i] == 0:
legal.append(i)
break
return legal
def is_finished(self):
# Horizontal check
for i in range(4):
for j in range(6):
if (
self.board[j][i] == self.player
and self.board[j][i + 1] == self.player
and self.board[j][i + 2] == self.player
and self.board[j][i + 3] == self.player
):
return True
# Vertical check
for i in range(7):
for j in range(3):
if (
self.board[j][i] == self.player
and self.board[j + 1][i] == self.player
and self.board[j + 2][i] == self.player
and self.board[j + 3][i] == self.player
):
return True
# x diag check
for i in range(4):
for j in range(3):
if (
self.board[j][i] == self.player
and self.board[j + 1][i + 1] == self.player
and self.board[j + 2][i + 2] == self.player
and self.board[j + 3][i + 3] == self.player
):
return True
# -x diag check
for i in range(4):
for j in range(3, 6):
if (
self.board[j][i] == self.player
and self.board[j - 1][i + 1] == self.player
and self.board[j - 2][i + 2] == self.player
and self.board[j - 3][i + 3] == self.player
):
return True
return False
def render(self):
print(self.player * self.get_observation()[::-1])
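For a quick manual check of the new Connect4 wrapper, a minimal sketch (not part of the commit; it assumes this games module is imported) could be:

```python
game = Game()
observation = game.reset()                # flattened 6x7 board from the current player's view
observation, reward, done = game.step(3)  # drop a piece into column 3
print(game.env.legal_actions())           # columns that are not full yet
game.render()                             # prints the board and waits for Enter
```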
......@@ -11,12 +11,13 @@ class MuZeroConfig:
### Game
self.observation_shape = 8 # Dimensions of the game observation
self.action_space = [i for i in range(4)] # Fixed list of all possible actions
self.players = [i for i in range(1)] # List of players
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if game is not finished before
        self.num_simulations = 80 # Number of future moves self-simulated
self.max_moves = 200 # Maximum number of moves if game is not finished before
        self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
......@@ -95,7 +96,16 @@ class Game:
            The new observation, the reward and a boolean indicating whether the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
return numpy.array(observation).flatten(), reward/100, done
def to_play(self):
"""
Return the current player.
Returns:
            The current player; it should be an element of the players list in the config.
"""
return 0
def reset(self):
"""
......
......@@ -61,7 +61,13 @@ class MuZeroNetwork(torch.nn.Module):
return policy_logit, value
def representation(self, observation):
return self.representation_network(observation)
encoded_state = self.representation_network(observation)
# TODO: Try to remove tanh activation
        # Scale encoded state between [0, 1] (See paper appendix Training)
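        # Min-max normalization: the smallest entry maps to 0 and the largest to 1.
        # Note that the min/max are taken over the whole tensor (all batch elements), not per sample.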
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
)
return encoded_state
def dynamics(self, encoded_state, action):
# Stack encoded_state with one hot action (See paper appendix Network Architecture)
......@@ -74,15 +80,16 @@ class MuZeroNetwork(torch.nn.Module):
x = torch.cat((encoded_state, action_one_hot), dim=1)
next_encoded_state = self.dynamics_encoded_state_network(x)
# TODO: Try to remove tanh activation
# Scale encoded state between [0, 1] (See paper appendix Training)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
)
reward = self.dynamics_reward_network(x)
return next_encoded_state, reward
def initial_inference(self, observation):
encoded_state = self.representation(observation)
# Scale encoded state between [0, 1] (See paper Training appendix)
encoded_state = (encoded_state - torch.min(encoded_state)) / (
torch.max(encoded_state) - torch.min(encoded_state)
)
policy_logit, value = self.prediction(encoded_state)
return (
value,
......@@ -93,10 +100,6 @@ class MuZeroNetwork(torch.nn.Module):
def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
# Scale encoded state between [0, 1] (See paper Training appendix)
next_encoded_state = (next_encoded_state - torch.min(next_encoded_state)) / (
torch.max(next_encoded_state) - torch.min(next_encoded_state)
)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
......
......@@ -45,10 +45,7 @@ class MuZero:
)
raise err
try:
os.mkdir(os.path.join(self.config.results_path))
except FileExistsError:
pass
os.makedirs(os.path.join(self.config.results_path), exist_ok=True)
        # Fix random generator seed for reproducibility
numpy.random.seed(self.config.seed)
......@@ -69,9 +66,9 @@ class MuZero:
)
# Initialize workers
training_worker = trainer.Trainer.remote(
copy.deepcopy(self.muzero_weights), self.config
)
training_worker = trainer.Trainer.options(
num_gpus=1 if "cuda" in self.config.training_device else 0
).remote(copy.deepcopy(self.muzero_weights), self.config)
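        # Only the trainer worker reserves a GPU (when training_device is "cuda");
        # the self-play workers keep their model on the CPU.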
shared_storage_worker = shared_storage.SharedStorage.remote(
copy.deepcopy(self.muzero_weights), self.game_name, self.config,
)
......@@ -106,6 +103,7 @@ class MuZero:
)
counter = 0
infos = ray.get(shared_storage_worker.get_infos.remote())
try:
while infos["training_step"] < self.config.training_steps:
# Get and save real time performance
infos = ray.get(shared_storage_worker.get_infos.remote())
......@@ -136,6 +134,8 @@ class MuZero:
)
counter += 1
time.sleep(3)
except KeyboardInterrupt:
pass
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
......@@ -149,7 +149,6 @@ class MuZero:
copy.deepcopy(self.muzero_weights), self.Game(), self.config
)
test_rewards = []
with torch.no_grad():
for _ in range(self.config.test_episodes):
history = ray.get(self_play_workers.play_game.remote(0, render))
test_rewards.append(sum(history.rewards))
......@@ -170,5 +169,5 @@ if __name__ == "__main__":
muzero = MuZero("cartpole")
muzero.train()
#muzero.load_model()
muzero.load_model()
muzero.test()
......@@ -33,24 +33,17 @@ class ReplayBuffer:
for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer)
game_pos = sample_position(game_history)
actions = game_history.history[
game_pos : game_pos + self.config.num_unroll_steps
]
            # Repeat the previous action to make 'actions' of length 'num_unroll_steps'
actions.extend(
[
actions[-1]
for _ in range(self.config.num_unroll_steps - len(actions) + 1)
]
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value, reward, policy = make_target(
value, reward, policy, actions = make_target(
game_history,
game_pos,
self.config.num_unroll_steps,
self.config.td_steps,
self.config.discount,
)
observation_batch.append(game_history.observation_history[game_pos])
action_batch.append(actions)
value_batch.append(value)
reward_batch.append(reward)
policy_batch.append(policy)
......@@ -63,7 +56,7 @@ def sample_game(buffer):
Sample game from buffer either uniformly or according to some priority.
"""
    # TODO: sample with probability linked to the highest difference between real and
# predicted value (see paper appendix Training)
# predicted value (See paper appendix Training)
return numpy.random.choice(buffer)
......@@ -75,29 +68,29 @@ def sample_position(game_history):
return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps):
def make_target(game_history, state_index, num_unroll_steps, td_steps, discount):
"""
The value target is the discounted root value of the search tree td_steps into the
future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies = [], [], []
target_values, target_rewards, target_policies, actions = [], [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = (
game_history.root_values[bootstrap_index]
* game_history.discount ** td_steps
)
value = game_history.root_values[bootstrap_index] * discount ** td_steps
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * game_history.discount ** i
value += reward * discount ** i
if current_index < len(game_history.root_values):
target_values.append(0.25 * value)
            # TODO: Scale and transform the reward and the value; don't forget the inverse transform at inference time (See paper appendix Network Architecture)
# Value target could be scaled by 0.25 (See paper appendix Reanalyze)
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
actions.append(game_history.action_history[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
......@@ -109,5 +102,6 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps):
for _ in range(len(game_history.child_visits[0]))
]
)
actions.append(game_history.action_history[-1])
return target_values, target_rewards, target_policies
return target_values, target_rewards, target_policies, actions
import copy
import math
import time
import copy
import numpy
import ray
import torch
......@@ -26,11 +27,10 @@ class SelfPlay:
self.config.hidden_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device('cpu'))
self.model.to(torch.device("cpu"))
self.model.eval()
def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
with torch.no_grad():
while True:
self.model.set_weights(
copy.deepcopy(ray.get(shared_storage.get_weights.remote()))
......@@ -63,12 +63,18 @@ class SelfPlay:
"""
        Play one game with actions based on the Monte Carlo tree search at each move.
"""
game_history = GameHistory(self.config.discount)
game_history = GameHistory()
observation = self.game.reset()
game_history.observation_history.append(observation)
done = False
while not done and len(game_history.history) < self.config.max_moves:
root = MCTS(self.config).run(self.model, observation, True if temperature else False)
with torch.no_grad():
while not done and len(game_history.action_history) < self.config.max_moves:
root = MCTS(self.config).run(
self.model,
observation,
self.game.to_play(),
True if temperature else False,
)
action = select_action(root, temperature)
......@@ -79,7 +85,7 @@ class SelfPlay:
game_history.observation_history.append(observation)
game_history.rewards.append(reward)
game_history.history.append(action)
game_history.action_history.append(action)
game_history.store_search_statistics(root, self.config.action_space)
self.game.close()
......@@ -98,7 +104,7 @@ def select_action(node, temperature):
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
else:
# See paper Data Generation appendix
# See paper appendix Data Generation
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
......@@ -125,7 +131,7 @@ class MCTS:
def __init__(self, config):
self.config = config
def run(self, model, observation, add_exploration_noise):
def run(self, model, observation, to_play, add_exploration_noise):
"""
At the root of the search tree we use the representation function to obtain a
hidden state given the current observation.
......@@ -143,7 +149,11 @@ class MCTS:
observation
)
root.expand(
self.config.action_space, expected_reward, policy_logits, hidden_state
self.config.action_space,
to_play,
expected_reward,
policy_logits,
hidden_state,
)
if add_exploration_noise:
root.add_exploration_noise(
......@@ -154,6 +164,7 @@ class MCTS:
min_max_stats = MinMaxStats()
for _ in range(self.config.num_simulations):
virtual_to_play = to_play
node = root
search_path = [node]
......@@ -162,6 +173,12 @@ class MCTS:
last_action = action
search_path.append(node)
# Players play turn by turn
if virtual_to_play + 1 < len(self.config.players):
virtual_to_play = self.config.players[virtual_to_play + 1]
else:
virtual_to_play = self.config.players[0]
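                # With players = [0, 1] (e.g. Connect4) this alternates 0 -> 1 -> 0 -> ...;
                # with a single-player game (players = [0]) it always stays at 0.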
# Inside the search tree we use the dynamics function to obtain the next hidden
# state given an action and the previous hidden state
parent = search_path[-2]
......@@ -169,9 +186,15 @@ class MCTS:
parent.hidden_state,
torch.tensor([[last_action]]).to(parent.hidden_state.device),
)
node.expand(self.config.action_space, reward, policy_logits, hidden_state)
node.expand(
self.config.action_space,
virtual_to_play,
reward,
policy_logits,
hidden_state,
)
self.backpropagate(search_path, value.item(), min_max_stats)
self.backpropagate(search_path, value.item(), to_play, min_max_stats)
return root
......@@ -202,13 +225,13 @@ class MCTS:
return prior_score + value_score
def backpropagate(self, search_path, value, min_max_stats):
def backpropagate(self, search_path, value, to_play, min_max_stats):
"""
At the end of a simulation, we propagate the evaluation all the way up the tree
to the root.
"""
for node in search_path:
node.value_sum += value # if node.to_play == to_play else -value
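            # Multiplayer: the search value is expressed from the perspective of to_play,
            # so nodes where the opponent is to move accumulate the negated value.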
node.value_sum += value if node.to_play == to_play else -value
node.visit_count += 1
min_max_stats.update(node.value())
......@@ -233,11 +256,12 @@ class Node:
return 0
return self.value_sum / self.visit_count
def expand(self, actions, reward, policy_logits, hidden_state):
def expand(self, actions, to_play, reward, policy_logits, hidden_state):
"""
We expand a node using the value, reward and policy prediction obtained from the
neural network.
"""
self.to_play = to_play
self.reward = reward
self.hidden_state = hidden_state
policy = {a: math.exp(policy_logits[0][a]) for a in actions}
......@@ -262,13 +286,12 @@ class GameHistory:
    Store only useful information of a self-play game.
"""
def __init__(self, discount):
def __init__(self):
self.observation_history = []
self.history = []
self.action_history = []
self.rewards = []
self.child_visits = []
self.root_values = []
self.discount = discount
def store_search_statistics(self, root, action_space):
sum_visits = sum(child.visit_count for child in root.children.values())
......
......@@ -7,7 +7,7 @@ import torch
import models
@ray.remote(num_gpus=1 if torch.cuda.is_available() else 0)
@ray.remote
class Trainer:
"""
    Class which runs in a dedicated thread to train a neural network and save it
......@@ -64,12 +64,7 @@ class Trainer:
"""
Perform one training step.
"""
# Update learning rate
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
self.update_lr()
(
observation_batch,
......@@ -116,9 +111,7 @@ class Trainer:
reward_loss += current_reward_loss
policy_loss += current_policy_loss
loss = (
value_loss + reward_loss + policy_loss
).mean()
loss = (value_loss + reward_loss + policy_loss).mean()
# Scale gradient by number of unroll steps (See paper Training appendix)
loss.register_hook(lambda grad: grad * 1 / self.config.num_unroll_steps)
......@@ -136,6 +129,16 @@ class Trainer:
policy_loss.mean().item(),
)
def update_lr(self):
"""
Update learning rate
"""
lr = self.config.lr_init * self.config.lr_decay_rate ** (
self.training_step / self.config.lr_decay_steps
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
......@@ -143,5 +146,7 @@ def loss_function(
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss()(value, target_value)
reward_loss = torch.nn.MSELoss()(reward, target_reward)
policy_loss = torch.mean(torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1))
policy_loss = torch.mean(
torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1)
)
return value_loss, reward_loss, policy_loss