Commit fe791e86 by Werner Duvaud

Add cross entropy, play against human, policy mask

parent 26f62896
......@@ -22,9 +22,10 @@ MuZero is a model based reinforcement learning algorithm, successor of AlphaZero
* [x] Easily adaptable for new games
* [x] [Examples](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py) of board and Gym games (See [list below](https://github.com/werner-duvaud/muzero-general#games-already-implemented-with-pretrained-network-available))
* [x] [Pretrained weights](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained) available
* [ ] Play against MuZero mode with policy and value tracking
* [ ] Add human vs MuZero tracking in TensorBoard
* [ ] Residual Network
* [ ] Atari games
* [ ] Appendix Reanalyse of the paper
* [ ] Windows support ([workaround by ihexx](https://github.com/ihexx/muzero-general))
## Demo
......
......@@ -17,7 +17,7 @@ class MuZeroConfig:
### Self-Play
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 500 # Maximum number of moves if the game is not finished before
self.num_simulations = 80 # Number of future moves self-simulated
self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting
......@@ -32,7 +32,8 @@ class MuZeroConfig:
### Network
self.encoding_size = 32
self.hidden_size = 64
self.hidden_layers = [64]
self.support_size = 10 # Value and reward are scaled (with an almost-sqrt transform) and encoded on a vector ranging from -support_size to support_size
### Training
......@@ -42,7 +43,7 @@ class MuZeroConfig:
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 30 # Number of steps in the future to take into account for calculating the target value
self.td_steps = 10 # Number of steps in the future to take into account for calculating the target value
self.training_delay = 0 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
self.training_device = "cuda" if torch.cuda.is_available() else "cpu" # Train on GPU if available
......@@ -50,15 +51,16 @@ class MuZeroConfig:
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.008 # Initial learning rate
self.lr_init = 0.5 # Initial learning rate
self.lr_decay_rate = 1
self.lr_decay_steps = 10000
self.lr_decay_steps = 1000
### Test
self.test_episodes = 2 # Number of games played to evaluate the network
def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
......@@ -107,6 +109,19 @@ class Game:
"""
return 0
def legal_actions(self):
"""
Should return the legal actions at each turn; if they are not available, it can return
the whole action space. At each turn, the game has to be able to handle one of the returned actions.
For complex games where computing legal moves is too slow, the idea is to define the legal actions
as equal to the action space but to return a negative reward if the action is illegal.
Returns:
An array of integers, a subset of the action space.
"""
return [i for i in range(2)]
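To make the docstring's suggestion concrete, here is a minimal, hypothetical sketch (mine, not code from the commit; `env.is_legal()` and `env.get_observation()` are assumed helpers) of keeping legal_actions equal to the full action space while penalizing illegal moves:

```python
# Hypothetical sketch: legal_actions() exposed to MCTS is the full action
# space; the environment only checks the single played action, which is
# usually cheap even when enumerating every legal move is not.
def step_with_penalty(env, action):
    if not env.is_legal(action):                  # assumed cheap single-action check
        return env.get_observation(), -1, True    # penalize and end the game
    return env.step(action)                       # observation, reward, done
```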
def reset(self):
"""
Reset the game for a new game.
......
......@@ -32,7 +32,8 @@ class MuZeroConfig:
### Network
self.encoding_size = 32
self.hidden_size = 64
self.hidden_layers = [64]
self.support_size = 10 # Value and reward are scaled (with an almost-sqrt transform) and encoded on a vector ranging from -support_size to support_size
### Training
......@@ -93,13 +94,7 @@ class Game:
Returns:
The new observation, the reward, and a boolean indicating whether the game has ended.
"""
if action not in self.env.legal_actions():
observation, reward, done = self.env.step(self.env.legal_actions()[0])
reward = -1
done = True
else:
observation, reward, done = self.env.step(action)
observation, reward, done = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
def to_play(self):
......@@ -111,6 +106,19 @@ class Game:
"""
return self.env.to_play()
def legal_actions(self):
"""
Should return the legal actions at each turn; if they are not available, it can return
the whole action space. At each turn, the game has to be able to handle one of the returned actions.
For complex games where computing legal moves is too slow, the idea is to define the legal actions
as equal to the action space but to return a negative reward if the action is illegal.
Returns:
An array of integers, a subset of the action space.
"""
return self.env.legal_actions()
def reset(self):
"""
Reset the game for a new game.
......@@ -155,8 +163,12 @@ class Connect4:
done = self.is_finished()
reward = 1 if done and 0 < len(self.legal_actions()) else 0
self.player *= -1
return self.get_observation(), 1 if done else 0, done
return self.get_observation(), reward, done
def get_observation(self):
if self.player == 1:
......@@ -167,10 +179,8 @@ class Connect4:
def legal_actions(self):
legal = []
for i in range(7):
for j in range(6):
if self.board[j][i] == 0:
legal.append(i)
break
if self.board[5][i] == 0:
legal.append(i)
return legal
def is_finished(self):
......@@ -218,6 +228,9 @@ class Connect4:
):
return True
if len(self.legal_actions()) == 0:
return True
return False
def render(self):
......
......@@ -32,17 +32,18 @@ class MuZeroConfig:
### Network
self.encoding_size = 64
self.hidden_size = 128
self.hidden_layers = [128]
self.support_size = 10 # Value and reward are scaled (with an almost-sqrt transform) and encoded on a vector ranging from -support_size to support_size
### Training
self.results_path = "./pretrained" # Path to store the model weights
self.training_steps = 3000 # Total number of training steps (i.e. weight updates according to a batch)
self.training_steps = 15000 # Total number of training steps (i.e. weight updates according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 10 # Number of game moves to keep for every batch element
self.checkpoint_interval = 3 # Number of training steps before using the model for self-playing
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 10 # Number of steps in the future to take into account for calculating the target value
self.td_steps = 20 # Number of steps in the future to take into account for calculating the target value
self.training_delay = 0 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
self.training_device = "cuda" if torch.cuda.is_available() else "cpu" # Train on GPU if available
......@@ -50,9 +51,9 @@ class MuZeroConfig:
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = 0.00005 # Initial learning rate
self.lr_decay_rate = 1
self.lr_decay_steps = 100000
self.lr_init = 0.11 # Initial learning rate
self.lr_decay_rate = 0.8
self.lr_decay_steps = 1000
### Test
......@@ -67,10 +68,8 @@ class MuZeroConfig:
Returns:
Positive float.
"""
if trained_steps < 0.2 * self.training_steps:
return float('inf')
if trained_steps < 0.5 * self.training_steps:
return 0.8
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
......@@ -98,7 +97,7 @@ class Game:
The new observation, the reward, and a boolean indicating whether the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
return numpy.array(observation).flatten(), reward, done
return numpy.array(observation).flatten(), reward/5, done
def to_play(self):
"""
......@@ -109,6 +108,19 @@ class Game:
"""
return 0
def legal_actions(self):
"""
Should return the legal actions at each turn; if they are not available, it can return
the whole action space. At each turn, the game has to be able to handle one of the returned actions.
For complex games where computing legal moves is too slow, the idea is to define the legal actions
as equal to the action space but to return a negative reward if the action is illegal.
Returns:
An array of integers, a subset of the action space.
"""
return [i for i in range(4)]
def reset(self):
"""
Reset the game for a new game.
......
......@@ -3,20 +3,21 @@ import torch
class FullyConnectedNetwork(torch.nn.Module):
def __init__(
self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
self, input_size, layer_sizes, output_size, activation=torch.nn.Tanh()
):
super(FullyConnectedNetwork, self).__init__()
layers_sizes.insert(0, input_size)
sizes_list = layer_sizes.copy()
sizes_list.insert(0, input_size)
layers = []
if 1 < len(layers_sizes):
for i in range(len(layers_sizes) - 1):
if 1 < len(sizes_list):
for i in range(len(sizes_list) - 1):
layers.extend(
[
torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
torch.nn.Linear(sizes_list[i], sizes_list[i + 1]),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
layers.append(torch.nn.Linear(sizes_list[-1], output_size))
if activation:
layers.append(activation)
self.layers = torch.nn.ModuleList(layers)
......@@ -27,32 +28,42 @@ class FullyConnectedNetwork(torch.nn.Module):
return x
# TODO: unified residual network
class MuZeroNetwork(torch.nn.Module):
def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
def __init__(
self,
observation_size,
action_space_size,
encoding_size,
hidden_layers,
support_size,
):
super().__init__()
self.action_space_size = action_space_size
self.full_support_size = 2 * support_size + 1
self.representation_network = FullyConnectedNetwork(
observation_size, [], encoding_size
)
self.dynamics_encoded_state_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], encoding_size
encoding_size + self.action_space_size, hidden_layers, encoding_size
)
# Gradient scaling (See paper appendix Training)
self.dynamics_encoded_state_network.register_backward_hook(
lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
)
self.dynamics_reward_network = FullyConnectedNetwork(
encoding_size + self.action_space_size, [hidden_size], 1, activation=None
encoding_size + self.action_space_size,
hidden_layers,
self.full_support_size,
activation=None,
)
self.prediction_policy_network = FullyConnectedNetwork(
encoding_size, [], self.action_space_size, activation=None
)
self.prediction_value_network = FullyConnectedNetwork(
encoding_size, [], 1, activation=None
encoding_size, [], self.full_support_size, activation=None
)
def prediction(self, encoded_state):
......@@ -93,7 +104,7 @@ class MuZeroNetwork(torch.nn.Module):
policy_logit, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation)).to(observation.device),
torch.zeros(len(observation), self.full_support_size).to(observation.device),
policy_logit,
encoded_state,
)
......
......@@ -56,7 +56,8 @@ class MuZero:
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
self.config.hidden_layers,
self.config.support_size
).get_weights()
def train(self):
......@@ -136,14 +137,20 @@ class MuZero:
time.sleep(3)
except KeyboardInterrupt as err:
# Comment the line below to be able to stop the training but keep running
raise err
# raise err
pass
self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
ray.shutdown()
def test(self, render=True):
def test(self, render, muzero_player):
"""
Test the model in a dedicated thread.
Args:
render: Boolean to display the environment or not.
muzero_player: Integer giving the player number MuZero plays in multiplayer
games; None lets MuZero play every player turn by turn.
"""
print("\nTesting...")
ray.init()
......@@ -152,7 +159,7 @@ class MuZero:
)
test_rewards = []
for _ in range(self.config.test_episodes):
history = ray.get(self_play_workers.play_game.remote(0, render))
history = ray.get(self_play_workers.play_game.remote(0, render, muzero_player))
test_rewards.append(sum(history.rewards))
ray.shutdown()
return test_rewards
......@@ -168,8 +175,15 @@ class MuZero:
if __name__ == "__main__":
# Use the game and config from the ./games folder
muzero = MuZero("cartpole")
## Train
muzero.train()
## Test
muzero.load_model()
muzero.test()
# Render some self-played games
muzero.test(render=True, muzero_player=None)
# Let user play against MuZero (MuZero is player 0 here)
# muzero.test(render=True, muzero_player=0)
......@@ -31,10 +31,10 @@ class ReplayBuffer:
[],
)
for _ in range(self.config.batch_size):
game_history = sample_game(self.buffer)
game_pos = sample_position(game_history)
game_history = self.sample_game(self.buffer)
game_pos = self.sample_position(game_history)
value, reward, policy, actions = make_target(
value, reward, policy, actions = self.make_target(
game_history,
game_pos,
self.config.num_unroll_steps,
......@@ -50,58 +50,57 @@ class ReplayBuffer:
return observation_batch, action_batch, value_batch, reward_batch, policy_batch
def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability linked to the highest difference between real and
# predicted value (See paper appendix Training)
return numpy.random.choice(buffer)
def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: sample according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
def make_target(game_history, state_index, num_unroll_steps, td_steps, discount):
"""
The value target is the discounted root value of the search tree td_steps into the
future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies, actions = [], [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = game_history.root_values[bootstrap_index] * discount ** td_steps
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * discount ** i
if current_index < len(game_history.root_values):
# TODO: Scale and transform the reward and the value, don't forget the invert transform in inference time (See paper appendix Network Architecture)
# Value target could be scaled by 0.25 (See paper appendix Reanalyze)
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
actions.append(game_history.action_history[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game_history.child_visits[0])
for _ in range(len(game_history.child_visits[0]))
]
)
actions.append(game_history.action_history[-1])
return target_values, target_rewards, target_policies, actions
@staticmethod
def sample_game(buffer):
"""
Sample game from buffer either uniformly or according to some priority.
"""
# TODO: sample with probability linked to the highest difference between real and
# predicted value (See paper appendix Training)
return numpy.random.choice(buffer)
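As a reading aid for the TODO above, one hypothetical shape the prioritized variant could take (the priorities list is assumed, e.g. the mean absolute gap between predicted and target values; none of this is in the commit):

```python
import numpy

# Hypothetical prioritized sampling: `priorities` is an assumed per-game list
# of positive scores, e.g. mean |predicted value - target value| per game.
def sample_game_prioritized(buffer, priorities):
    probs = numpy.asarray(priorities, dtype=numpy.float64)
    probs = probs / probs.sum()
    return buffer[numpy.random.choice(len(buffer), p=probs)]
```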
@staticmethod
def sample_position(game_history):
"""
Sample position from game either uniformly or according to some priority.
"""
# TODO: sample according to some priority
return numpy.random.choice(range(len(game_history.rewards)))
@staticmethod
def make_target(game_history, state_index, num_unroll_steps, td_steps, discount):
"""
The value target is the discounted root value of the search tree td_steps into the
future, plus the discounted sum of all rewards until then.
"""
target_values, target_rewards, target_policies, actions = [], [], [], []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(game_history.root_values):
value = game_history.root_values[bootstrap_index] * discount ** td_steps
else:
value = 0
for i, reward in enumerate(game_history.rewards[current_index:bootstrap_index]):
value += reward * discount ** i
if current_index < len(game_history.root_values):
# Value target could be scaled by 0.25 (See paper appendix Reanalyze)
target_values.append(value)
target_rewards.append(game_history.rewards[current_index])
target_policies.append(game_history.child_visits[current_index])
actions.append(game_history.action_history[current_index])
else:
# States past the end of games are treated as absorbing states
target_values.append(0)
target_rewards.append(0)
# Uniform policy to give the tensor a valid dimension
target_policies.append(
[
1 / len(game_history.child_visits[0])
for _ in range(len(game_history.child_visits[0]))
]
)
actions.append(game_history.action_history[-1])
return target_values, target_rewards, target_policies, actions
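For reference (my summary, not text from the commit), the value target assembled in make_target is the standard n-step return with n = td_steps:

$$ z_t \;=\; \sum_{i=0}^{n-1} \gamma^{\,i}\, u_{t+i} \;+\; \gamma^{\,n}\, \nu_{t+n} $$

where $u$ are the stored rewards, $\nu$ the stored MCTS root values, $\gamma$ the discount, and the bootstrap term is dropped whenever $t+n$ falls past the end of the game.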
......@@ -24,7 +24,8 @@ class SelfPlay:
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
self.config.hidden_layers,
self.config.support_size,
)
self.model.set_weights(initial_weights)
self.model.to(torch.device("cpu"))
......@@ -46,7 +47,7 @@ class SelfPlay:
]
)
)
game_history = self.play_game(temperature, False)
game_history = self.play_game(temperature, False, None)
# Save to the shared storage
if test_mode:
......@@ -59,7 +60,7 @@ class SelfPlay:
if not test_mode and self.config.self_play_delay:
time.sleep(self.config.self_play_delay)
def play_game(self, temperature, render):
def play_game(self, temperature, render, play_against_human_player):
"""
Play one game with actions based on the Monte Carlo tree search at each move.
"""
......@@ -67,20 +68,32 @@ class SelfPlay:
observation = self.game.reset()
game_history.observation_history.append(observation)
done = False
if render:
self.game.render()
with torch.no_grad():
while not done and len(game_history.action_history) < self.config.max_moves:
root = MCTS(self.config).run(
self.model,
observation,
self.game.to_play(),
False if temperature == 0 else True,
)
current_player = self.game.to_play()
if play_against_human_player is None or play_against_human_player == current_player:
root = MCTS(self.config).run(
self.model,
observation,
self.game.legal_actions(),
current_player,
False if temperature == 0 else True,
)
action = select_action(root, temperature)
action = self.select_action(root, temperature)
observation, reward, done = self.game.step(action)
observation, reward, done = self.game.step(action)
if play_against_human_player is not None and current_player != play_against_human_player:
action = int(input("Enter the action of player {} : ".format(current_player)))
observation, reward, done = self.game.step(action)
if render:
print("Action : {}".format(action))
self.game.render()
game_history.observation_history.append(observation)
......@@ -91,31 +104,30 @@ class SelfPlay:
self.game.close()
return game_history
def select_action(node, temperature):
"""
Select action according to the visit count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function
in the config.
"""
visit_counts = numpy.array(
[[child.visit_count, action] for action, child in node.children.items()]
).T
if temperature == 0:
action_pos = numpy.argmax(visit_counts[0])
elif temperature == float("inf"):
action_pos = numpy.random.choice(len(visit_counts[1]))
else:
# See paper appendix Data Generation
visit_count_distribution = visit_counts[0] ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action_pos = numpy.random.choice(
len(visit_counts[1]), p=visit_count_distribution
@staticmethod
def select_action(node, temperature):
"""
Select action according to the visit count distribution and the temperature.
The temperature is changed dynamically with the visit_softmax_temperature function
in the config.
"""
visit_counts = numpy.array(
[child.visit_count for child in node.children.values()]
)
actions = [action for action in node.children.keys()]
if temperature == 0:
action = actions[numpy.argmax(visit_counts)]
elif temperature == float("inf"):
action = numpy.random.choice(actions)
else:
# See paper appendix Data Generation
visit_count_distribution = visit_counts ** (1 / temperature)
visit_count_distribution = visit_count_distribution / sum(
visit_count_distribution
)
action = numpy.random.choice(actions, p=visit_count_distribution)
return visit_counts[1][action_pos]
return action
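A quick illustration (mine, not part of the commit) of what the 1/temperature exponent does to the visit-count distribution:

```python
import numpy

# Raising visit counts to the power 1/T sharpens the distribution for T < 1 and
# flattens it for T > 1; T -> 0 approaches argmax, T -> inf approaches uniform.
visit_counts = numpy.array([50.0, 30.0, 20.0])
for temperature in (0.5, 1.0, 2.0):
    d = visit_counts ** (1 / temperature)
    print(temperature, numpy.round(d / d.sum(), 3))
```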
# Game independent
......@@ -130,7 +142,7 @@ class MCTS:
def __init__(self, config):
self.config = config
def run(self, model, observation, to_play, add_exploration_noise):
def run(self, model, observation, legal_actions, to_play, add_exploration_noise):
"""
At the root of the search tree we use the representation function to obtain a
hidden state given the current observation.
......@@ -147,12 +159,11 @@ class MCTS:
_, expected_reward, policy_logits, hidden_state = model.initial_inference(
observation
)
expected_reward = self.support_to_scalar(
expected_reward, self.config.support_size
)
root.expand(
self.config.action_space,
to_play,
expected_reward,
policy_logits,
hidden_state,
legal_actions, to_play, expected_reward, policy_logits, hidden_state,
)
if add_exploration_noise:
root.add_exploration_noise(
......@@ -183,8 +194,10 @@ class MCTS:
parent = search_path[-2]
value, reward, policy_logits, hidden_state = model.recurrent_inference(
parent.hidden_state,
torch.tensor([[last_action]]).to(parent.hidden_state.device),
torch.tensor([last_action]).unsqueeze(1).to(parent.hidden_state.device),
)
value = self.support_to_scalar(value, self.config.support_size)
reward = self.support_to_scalar(reward, self.config.support_size)
node.expand(
self.config.action_space,
virtual_to_play,
......@@ -236,6 +249,29 @@ class MCTS:
value = node.reward + self.config.discount * value
@staticmethod
def support_to_scalar(logits, support_size):
"""
Transform a categorical representation to a scalar
See paper appendix Network Architecture
"""
# Decode to a scalar
probs = torch.softmax(logits, dim=1)
support = (
torch.tensor([x for x in range(-support_size, support_size + 1)])
.expand(probs.shape)
.to(device=probs.device)
)
x = torch.sum(support * probs, dim=1, keepdim=True)
# Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
x = torch.sign(x) * (
((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
** 2
- 1
)
return x
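As a reading aid (mine, not the commit's): support_to_scalar above applies the inverse of the value transform from https://arxiv.org/abs/1805.11593, h(x) = sign(x)(sqrt(|x| + 1) - 1) + εx with ε = 0.001, whose closed-form inverse is the expression implemented in the last lines:

$$ h^{-1}(x) \;=\; \operatorname{sign}(x)\left(\left(\frac{\sqrt{1 + 4\varepsilon\,(|x| + 1 + \varepsilon)} - 1}{2\varepsilon}\right)^{2} - 1\right) $$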
class Node:
def __init__(self, prior):
......@@ -294,7 +330,6 @@ class GameHistory:
def store_search_statistics(self, root, action_space):
sum_visits = sum(child.visit_count for child in root.children.values())
# TODO: action could be of any type, not only integers
self.child_visits.append(
[
root.children[a].visit_count / sum_visits if a in root.children else 0
......
......@@ -23,7 +23,8 @@ class Trainer:
self.config.observation_shape,
len(self.config.action_space),
self.config.encoding_size,
self.config.hidden_size,
self.config.hidden_layers,
self.config.support_size
)
self.model.set_weights(initial_weights)
self.model.to(torch.device(config.training_device))
......@@ -81,6 +82,9 @@ class Trainer:
target_reward = torch.tensor(target_reward).float().to(device)
target_policy = torch.tensor(target_policy).float().to(device)
target_value = self.scalar_to_support(target_value, self.config.support_size)
target_reward = self.scalar_to_support(target_reward, self.config.support_size)
value, reward, policy_logits, hidden_state = self.model.initial_inference(
observation_batch
)
......@@ -99,13 +103,13 @@ class Trainer:
current_value_loss,
current_reward_loss,
current_policy_loss,
) = loss_function(
) = self.loss_function(
value.squeeze(-1),
reward.squeeze(-1),
policy_logits,
target_value[:, i],
target_reward[:, i],
target_policy[:, i, :],
target_policy[:, i],
)
value_loss += current_value_loss
reward_loss += current_reward_loss
......@@ -139,14 +143,35 @@ class Trainer:
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# TODO: paper promotes cross entropy instead of MSE
value_loss = torch.nn.MSELoss()(value, target_value)
reward_loss = torch.nn.MSELoss()(reward, target_reward)
policy_loss = torch.mean(
torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1)
)
return value_loss, reward_loss, policy_loss
@staticmethod
def scalar_to_support(x, support_size):
"""
Transform a scalar to a categorical representation with (2 * support_size + 1) categories
See paper appendix Network Architecture
"""
# Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1 + 0.001 * x)
# Encode on a vector
x = torch.clamp(x, -support_size, support_size)
floor = x.floor()
ceil = x.ceil()
prob = x - floor
logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
logits.scatter_(
2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
)
logits.scatter_(
2, (ceil + support_size).long().unsqueeze(-1), prob.unsqueeze(-1)
)
return logits
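To illustrate the scatter-based encoding above with a toy example (mine, not code from the repo): after the h(x) scaling, a value of 2.4 with support_size = 10 is split between the bins for +2 and +3:

```python
import torch

# Toy two-hot encoding of a single already-scaled value x = 2.4, support_size = 10;
# bin indices 0..20 map to support values -10..+10.
x = torch.tensor([[2.4]])
support_size = 10
floor, prob = x.floor(), x - x.floor()
logits = torch.zeros(1, 1, 2 * support_size + 1)
logits.scatter_(2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1))
logits.scatter_(2, (x.ceil() + support_size).long().unsqueeze(-1), prob.unsqueeze(-1))
print(logits[0, 0, 12:14])  # ~tensor([0.6000, 0.4000]) -> support values +2 and +3
```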
@staticmethod
def loss_function(
value, reward, policy_logits, target_value, target_reward, target_policy
):
# Cross-entropy gave better convergence than MSE
value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1).mean()
reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
return value_loss, reward_loss, policy_loss
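A tiny standalone check (my own, not part of the commit) of the cross-entropy formulation above: with support_size = 10 the value and reward heads output 2 * 10 + 1 = 21 logits, and the loss reduces to -log p(target bin) when the target collapses to a single bin:

```python
import torch

# The value/reward targets produced by scalar_to_support are (near) two-hot
# vectors over 21 bins; the loss is the sum over bins of -target * log_softmax(pred).
value_logits = torch.randn(4, 21)      # predicted value logits for a batch of 4
target_value = torch.zeros(4, 21)
target_value[:, 12] = 1.0              # pretend every target scalar landed exactly on bin +2
value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value_logits)).sum(1).mean()
print(value_loss.item())
```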