Commit fe791e86 by Werner Duvaud

Add cross entropy, play against human, policy mask

parent 26f62896
@@ -22,9 +22,10 @@ MuZero is a model based reinforcement learning algorithm, successor of AlphaZero
 * [x] Easily adaptable for new games
 * [x] [Examples](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py) of board and Gym games (See [list below](https://github.com/werner-duvaud/muzero-general#games-already-implemented-with-pretrained-network-available))
 * [x] [Pretrained weights](https://github.com/werner-duvaud/muzero-general/tree/master/pretrained) available
-* [ ] Play against MuZero mode with policy and value tracking
+* [ ] Add human vs MuZero tracking in TensorBoard
 * [ ] Residual Network
 * [ ] Atari games
+* [ ] Appendix Reanalyse of the paper
 * [ ] Windows support ([workaround by ihexx](https://github.com/ihexx/muzero-general))

 ## Demo
@@ -17,7 +17,7 @@ class MuZeroConfig:
         ### Self-Play
         self.num_actors = 10  # Number of simultaneous threads self-playing to feed the replay buffer
         self.max_moves = 500  # Maximum number of moves if the game is not finished before
-        self.num_simulations = 80  # Number of future moves self-simulated
+        self.num_simulations = 50  # Number of future moves self-simulated
         self.discount = 0.997  # Chronological discount of the reward
         self.self_play_delay = 0  # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting

@@ -32,7 +32,8 @@ class MuZeroConfig:
         ### Network
         self.encoding_size = 32
-        self.hidden_size = 64
+        self.hidden_layers = [64]
+        self.support_size = 10  # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size

         ### Training

@@ -42,7 +43,7 @@ class MuZeroConfig:
         self.num_unroll_steps = 5  # Number of game moves to keep for every batch element
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
-        self.td_steps = 30  # Number of steps in the future to take into account for calculating the target value
+        self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
         self.training_delay = 0  # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
         self.training_device = "cuda" if torch.cuda.is_available() else "cpu"  # Train on GPU if available

@@ -50,15 +51,16 @@ class MuZeroConfig:
         self.momentum = 0.9

         # Exponential learning rate schedule
-        self.lr_init = 0.008  # Initial learning rate
+        self.lr_init = 0.5  # Initial learning rate
         self.lr_decay_rate = 1
-        self.lr_decay_steps = 10000
+        self.lr_decay_steps = 1000

         ### Test
         self.test_episodes = 2  # Number of games played to evaluate the network

     def visit_softmax_temperature_fn(self, trained_steps):
         """
         Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.

@@ -107,6 +109,19 @@ class Game:
         """
         return 0

+    def legal_actions(self):
+        """
+        Should return the legal actions at each turn. If that is not available, it can return
+        the whole action space. At each turn, the game has to be able to handle one of the returned actions.
+        For complex games where computing legal moves is too slow, the idea is to define the legal actions
+        as the whole action space but to return a negative reward if the action is illegal.
+
+        Returns:
+            An array of integers, a subset of the action space.
+        """
+        return [i for i in range(2)]
+
     def reset(self):
         """
         Reset the game for a new game.
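The docstring above suggests a fallback for games where enumerating legal moves is expensive: declare the whole action space legal and punish illegal moves instead. The Game wrapper in the Connect4 section below drops exactly this pattern now that real legal actions are available; a minimal standalone sketch of it, with `env` as a placeholder object exposing `legal_actions()` and `step()`:

```python
def step_with_penalty(env, action):
    """Fallback when legal_actions() simply returns the whole action space."""
    if action not in env.legal_actions():
        # Play any valid move so the environment stays consistent, then punish and terminate
        observation, reward, done = env.step(env.legal_actions()[0])
        reward = -1
        done = True
    else:
        observation, reward, done = env.step(action)
    return observation, reward, done
```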
@@ -32,7 +32,8 @@ class MuZeroConfig:
         ### Network
         self.encoding_size = 32
-        self.hidden_size = 64
+        self.hidden_layers = [64]
+        self.support_size = 10  # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size

         ### Training

@@ -93,13 +94,7 @@ class Game:
         Returns:
             The new observation, the reward and a boolean if the game has ended.
         """
-        if action not in self.env.legal_actions():
-            observation, reward, done = self.env.step(self.env.legal_actions()[0])
-            reward = -1
-            done = True
-        else:
-            observation, reward, done = self.env.step(action)
+        observation, reward, done = self.env.step(action)
         return numpy.array(observation).flatten(), reward, done

     def to_play(self):

@@ -111,6 +106,19 @@ class Game:
         """
         return self.env.to_play()

+    def legal_actions(self):
+        """
+        Should return the legal actions at each turn. If that is not available, it can return
+        the whole action space. At each turn, the game has to be able to handle one of the returned actions.
+        For complex games where computing legal moves is too slow, the idea is to define the legal actions
+        as the whole action space but to return a negative reward if the action is illegal.
+
+        Returns:
+            An array of integers, a subset of the action space.
+        """
+        return self.env.legal_actions()
+
     def reset(self):
         """
         Reset the game for a new game.

@@ -155,8 +163,12 @@ class Connect4:
         done = self.is_finished()

+        reward = 1 if done and 0 < len(self.legal_actions()) else 0
+
         self.player *= -1

-        return self.get_observation(), 1 if done else 0, done
+        return self.get_observation(), reward, done

     def get_observation(self):
         if self.player == 1:

@@ -167,10 +179,8 @@ class Connect4:
     def legal_actions(self):
         legal = []
         for i in range(7):
-            for j in range(6):
-                if self.board[j][i] == 0:
-                    legal.append(i)
-                    break
+            if self.board[5][i] == 0:
+                legal.append(i)
         return legal

     def is_finished(self):

@@ -218,6 +228,9 @@ class Connect4:
         ):
             return True

+        if len(self.legal_actions()) == 0:
+            return True
+
         return False

     def render(self):
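Two behaviour changes above are worth a closer look: a finished game now only pays reward 1 when there is a winner (a drawn, full board returns 0), and the legal move test only inspects the top cell of each column. A tiny standalone check of that top-row test, assuming (as the diff suggests) the board is a 6x7 numpy array whose row index 5 is the top row:

```python
import numpy

board = numpy.zeros((6, 7))
board[:, 3] = 1  # hypothetical position: column 3 has been filled to the top

# A column is playable iff its top cell is still empty,
# so scanning every row as the previous loop did is unnecessary.
legal = [col for col in range(7) if board[5][col] == 0]
print(legal)  # [0, 1, 2, 4, 5, 6]
```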
@@ -32,17 +32,18 @@ class MuZeroConfig:
         ### Network
         self.encoding_size = 64
-        self.hidden_size = 128
+        self.hidden_layers = [128]
+        self.support_size = 10  # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size

         ### Training
         self.results_path = "./pretrained"  # Path to store the model weights
-        self.training_steps = 3000  # Total number of training steps (i.e. weight updates according to a batch)
+        self.training_steps = 15000  # Total number of training steps (i.e. weight updates according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.num_unroll_steps = 10  # Number of game moves to keep for every batch element
-        self.checkpoint_interval = 3  # Number of training steps before using the model for self-playing
+        self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
         self.window_size = 1000  # Number of self-play games to keep in the replay buffer
-        self.td_steps = 10  # Number of steps in the future to take into account for calculating the target value
+        self.td_steps = 20  # Number of steps in the future to take into account for calculating the target value
         self.training_delay = 0  # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
         self.training_device = "cuda" if torch.cuda.is_available() else "cpu"  # Train on GPU if available

@@ -50,9 +51,9 @@ class MuZeroConfig:
         self.momentum = 0.9

         # Exponential learning rate schedule
-        self.lr_init = 0.00005  # Initial learning rate
-        self.lr_decay_rate = 1
-        self.lr_decay_steps = 100000
+        self.lr_init = 0.11  # Initial learning rate
+        self.lr_decay_rate = 0.8
+        self.lr_decay_steps = 1000

         ### Test

@@ -67,10 +68,8 @@ class MuZeroConfig:
         Returns:
             Positive float.
         """
-        if trained_steps < 0.2 * self.training_steps:
-            return float('inf')
         if trained_steps < 0.5 * self.training_steps:
-            return 0.8
+            return 1.0
         elif trained_steps < 0.75 * self.training_steps:
             return 0.5
         else:
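The hunk above cuts off at the final `else:`. For reference, the whole schedule would read roughly like the sketch below; the value returned by that last branch is an assumption (it is not visible in this diff), and the early phase that returned `float('inf')` (uniformly random play) is the part this commit removes.

```python
def visit_softmax_temperature(trained_steps, training_steps):
    """Module-level sketch of the config method: action selection gets greedier as training progresses."""
    if trained_steps < 0.5 * training_steps:
        return 1.0   # sample proportionally to the visit counts
    elif trained_steps < 0.75 * training_steps:
        return 0.5   # sharpen the distribution
    else:
        return 0.25  # assumed value: this branch is not visible in the hunk

print([visit_softmax_temperature(s, 15000) for s in (0, 8000, 14000)])  # [1.0, 0.5, 0.25]
```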
@@ -98,7 +97,7 @@ class Game:
             The new observation, the reward and a boolean if the game has ended.
         """
         observation, reward, done, _ = self.env.step(action)
-        return numpy.array(observation).flatten(), reward, done
+        return numpy.array(observation).flatten(), reward/5, done

     def to_play(self):
         """

@@ -109,6 +108,19 @@ class Game:
         """
         return 0

+    def legal_actions(self):
+        """
+        Should return the legal actions at each turn. If that is not available, it can return
+        the whole action space. At each turn, the game has to be able to handle one of the returned actions.
+        For complex games where computing legal moves is too slow, the idea is to define the legal actions
+        as the whole action space but to return a negative reward if the action is illegal.
+
+        Returns:
+            An array of integers, a subset of the action space.
+        """
+        return [i for i in range(4)]
+
     def reset(self):
         """
         Reset the game for a new game.
@@ -3,20 +3,21 @@ import torch

 class FullyConnectedNetwork(torch.nn.Module):
     def __init__(
-        self, input_size, layers_sizes, output_size, activation=torch.nn.Tanh()
+        self, input_size, layer_sizes, output_size, activation=torch.nn.Tanh()
     ):
         super(FullyConnectedNetwork, self).__init__()
-        layers_sizes.insert(0, input_size)
+        sizes_list = layer_sizes.copy()
+        sizes_list.insert(0, input_size)
         layers = []
-        if 1 < len(layers_sizes):
-            for i in range(len(layers_sizes) - 1):
+        if 1 < len(sizes_list):
+            for i in range(len(sizes_list) - 1):
                 layers.extend(
                     [
-                        torch.nn.Linear(layers_sizes[i], layers_sizes[i + 1]),
+                        torch.nn.Linear(sizes_list[i], sizes_list[i + 1]),
                         torch.nn.ReLU(),
                     ]
                 )
-        layers.append(torch.nn.Linear(layers_sizes[-1], output_size))
+        layers.append(torch.nn.Linear(sizes_list[-1], output_size))
         if activation:
             layers.append(activation)
         self.layers = torch.nn.ModuleList(layers)
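The `.copy()` removes an aliasing bug: the old constructor inserted `input_size` directly into the caller's list, so a `hidden_layers` list shared between several networks (as in `MuZeroNetwork` below) grew every time a network was built. A quick sketch with made-up sizes; only the list handling is the point here:

```python
import torch

hidden_layers = [64]

# Illustrative sizes, not taken from the diff
net = FullyConnectedNetwork(input_size=36, layer_sizes=hidden_layers, output_size=32)

print(hidden_layers)  # still [64]; before the fix, insert(0, ...) would have turned it into [36, 64]
print(net.layers)     # ModuleList: Linear(36, 64), ReLU, Linear(64, 32), Tanh
```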
@@ -27,32 +28,42 @@ class FullyConnectedNetwork(torch.nn.Module):
         return x


-# TODO: unified residual network
 class MuZeroNetwork(torch.nn.Module):
-    def __init__(self, observation_size, action_space_size, encoding_size, hidden_size):
+    def __init__(
+        self,
+        observation_size,
+        action_space_size,
+        encoding_size,
+        hidden_layers,
+        support_size,
+    ):
         super().__init__()
         self.action_space_size = action_space_size
+        self.full_support_size = 2 * support_size + 1

         self.representation_network = FullyConnectedNetwork(
             observation_size, [], encoding_size
         )

         self.dynamics_encoded_state_network = FullyConnectedNetwork(
-            encoding_size + self.action_space_size, [hidden_size], encoding_size
+            encoding_size + self.action_space_size, hidden_layers, encoding_size
         )
         # Gradient scaling (See paper appendix Training)
         self.dynamics_encoded_state_network.register_backward_hook(
             lambda module, grad_i, grad_o: (grad_i[0] * 0.5,)
         )
         self.dynamics_reward_network = FullyConnectedNetwork(
-            encoding_size + self.action_space_size, [hidden_size], 1, activation=None
+            encoding_size + self.action_space_size,
+            hidden_layers,
+            self.full_support_size,
+            activation=None,
         )

         self.prediction_policy_network = FullyConnectedNetwork(
             encoding_size, [], self.action_space_size, activation=None
         )
         self.prediction_value_network = FullyConnectedNetwork(
-            encoding_size, [], 1, activation=None
+            encoding_size, [], self.full_support_size, activation=None
         )

     def prediction(self, encoded_state):

@@ -93,7 +104,7 @@ class MuZeroNetwork(torch.nn.Module):
         policy_logit, value = self.prediction(encoded_state)
         return (
             value,
-            torch.zeros(len(observation)).to(observation.device),
+            torch.zeros(len(observation), self.full_support_size).to(observation.device),
             policy_logit,
             encoded_state,
         )
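With the new signature, the value and reward heads emit `2 * support_size + 1` logits instead of a single scalar. A quick sketch of building the network with the CartPole-style settings from the configs above; the 4-dimensional observation and the batch of random inputs are illustrative assumptions, not taken from the diff:

```python
import torch

model = MuZeroNetwork(
    observation_size=4,   # assumed: CartPole's observation, flattened
    action_space_size=2,
    encoding_size=32,
    hidden_layers=[64],
    support_size=10,
)

value, reward, policy_logits, hidden_state = model.initial_inference(torch.rand(8, 4))
print(value.shape, reward.shape)  # expected: torch.Size([8, 21]) torch.Size([8, 21]), i.e. 2 * 10 + 1 bins
print(policy_logits.shape)        # expected: torch.Size([8, 2])
```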
@@ -56,7 +56,8 @@ class MuZero:
             self.config.observation_shape,
             len(self.config.action_space),
             self.config.encoding_size,
-            self.config.hidden_size,
+            self.config.hidden_layers,
+            self.config.support_size
         ).get_weights()

     def train(self):

@@ -136,14 +137,20 @@ class MuZero:
                     time.sleep(3)
         except KeyboardInterrupt as err:
             # Comment the line below to be able to stop the training but keep running
-            raise err
+            # raise err
             pass
         self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
         ray.shutdown()

-    def test(self, render=True):
+    def test(self, render, muzero_player):
         """
         Test the model in a dedicated thread.
+
+        Args:
+            render: Boolean, whether to display the environment.
+
+            muzero_player: Integer with the player number of MuZero in case of multiplayer
+            games; None lets MuZero play every player turn by turn.
         """
         print("\nTesting...")
         ray.init()

@@ -152,7 +159,7 @@ class MuZero:
         )
         test_rewards = []
         for _ in range(self.config.test_episodes):
-            history = ray.get(self_play_workers.play_game.remote(0, render))
+            history = ray.get(self_play_workers.play_game.remote(0, render, muzero_player))
             test_rewards.append(sum(history.rewards))
         ray.shutdown()
         return test_rewards

@@ -168,8 +175,15 @@ class MuZero:

 if __name__ == "__main__":
+    # Use the game and config from the ./games folder
     muzero = MuZero("cartpole")
+
+    ## Train
     muzero.train()
+
+    ## Test
     muzero.load_model()
-    muzero.test()
+    # Render some self-played games
+    muzero.test(render=True, muzero_player=None)
+
+    # Let the user play against MuZero (MuZero is player 0 here)
+    # muzero.test(render=True, muzero_player=0)
@@ -31,10 +31,10 @@ class ReplayBuffer:
             [],
         )
         for _ in range(self.config.batch_size):
-            game_history = sample_game(self.buffer)
-            game_pos = sample_position(game_history)
-            value, reward, policy, actions = make_target(
+            game_history = self.sample_game(self.buffer)
+            game_pos = self.sample_position(game_history)
+            value, reward, policy, actions = self.make_target(
                 game_history,
                 game_pos,
                 self.config.num_unroll_steps,

@@ -50,8 +50,8 @@ class ReplayBuffer:
         return observation_batch, action_batch, value_batch, reward_batch, policy_batch

+    @staticmethod
     def sample_game(buffer):
         """
         Sample a game from the buffer, either uniformly or according to some priority.
         """

@@ -59,16 +59,16 @@ def sample_game(buffer):
         # predicted value (See paper appendix Training)
         return numpy.random.choice(buffer)

+    @staticmethod
     def sample_position(game_history):
         """
         Sample a position in the game, either uniformly or according to some priority.
         """
         # TODO: sample according to some priority
         return numpy.random.choice(range(len(game_history.rewards)))

+    @staticmethod
     def make_target(game_history, state_index, num_unroll_steps, td_steps, discount):
         """
         The value target is the discounted root value of the search tree td_steps into the
         future, plus the discounted sum of all rewards until then.

@@ -85,7 +85,6 @@ def make_target(game_history, state_index, num_unroll_steps, td_steps, discount)
             value += reward * discount ** i

         if current_index < len(game_history.root_values):
-            # TODO: Scale and transform the reward and the value, don't forget the invert transform in inference time (See paper appendix Network Architecture)
             # Value target could be scaled by 0.25 (See paper appendix Reanalyze)
             target_values.append(value)
             target_rewards.append(game_history.rewards[current_index])
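The make_target docstring describes an n-step return: the discounted search value td_steps ahead plus the discounted rewards collected until then. A standalone sketch of that computation; the bootstrap-to-zero behaviour when the game ends before td_steps is an assumption, since only the reward loop is visible in the hunk:

```python
def n_step_value(root_values, rewards, state_index, td_steps, discount):
    """Value target = discounted search value td_steps ahead + discounted rewards until then."""
    bootstrap_index = state_index + td_steps
    if bootstrap_index < len(root_values):
        value = root_values[bootstrap_index] * discount ** td_steps
    else:
        value = 0  # assumed: the game ended before the bootstrap position
    for i, reward in enumerate(rewards[state_index:bootstrap_index]):
        value += reward * discount ** i  # matches the loop shown in the hunk
    return value

# e.g. constant reward 1, discount 0.997, bootstrapping on a root value of 0.5
print(n_step_value([0.5] * 50, [1] * 50, state_index=0, td_steps=10, discount=0.997))
```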
@@ -24,7 +24,8 @@ class SelfPlay:
             self.config.observation_shape,
             len(self.config.action_space),
             self.config.encoding_size,
-            self.config.hidden_size,
+            self.config.hidden_layers,
+            self.config.support_size,
         )
         self.model.set_weights(initial_weights)
         self.model.to(torch.device("cpu"))

@@ -46,7 +47,7 @@ class SelfPlay:
                         ]
                     )
                 )
-                game_history = self.play_game(temperature, False)
+                game_history = self.play_game(temperature, False, None)

                 # Save to the shared storage
                 if test_mode:

@@ -59,7 +60,7 @@ class SelfPlay:
             if not test_mode and self.config.self_play_delay:
                 time.sleep(self.config.self_play_delay)

-    def play_game(self, temperature, render):
+    def play_game(self, temperature, render, play_against_human_player):
         """
         Play one game with actions based on the Monte Carlo tree search at each move.
         """

@@ -67,20 +68,32 @@ class SelfPlay:
         observation = self.game.reset()
         game_history.observation_history.append(observation)
         done = False

+        if render:
+            self.game.render()
+
         with torch.no_grad():
             while not done and len(game_history.action_history) < self.config.max_moves:
+                current_player = self.game.to_play()
+
+                if play_against_human_player is None or play_against_human_player == current_player:
                     root = MCTS(self.config).run(
                         self.model,
                         observation,
-                        self.game.to_play(),
+                        self.game.legal_actions(),
+                        current_player,
                         False if temperature == 0 else True,
                     )
-                action = select_action(root, temperature)
-                observation, reward, done = self.game.step(action)
+                    action = self.select_action(root, temperature)
+                    observation, reward, done = self.game.step(action)
+
+                if play_against_human_player is not None and current_player != play_against_human_player:
+                    action = int(input("Enter the action of player {} : ".format(current_player)))
+                    observation, reward, done = self.game.step(action)

                 if render:
+                    print("Action : {}".format(action))
                     self.game.render()

                 game_history.observation_history.append(observation)

@@ -91,31 +104,30 @@ class SelfPlay:
         self.game.close()
         return game_history

+    @staticmethod
     def select_action(node, temperature):
         """
         Select action according to the visit count distribution and the temperature.
         The temperature is changed dynamically with the visit_softmax_temperature function
         in the config.
         """
         visit_counts = numpy.array(
-            [[child.visit_count, action] for action, child in node.children.items()]
-        ).T
+            [child.visit_count for child in node.children.values()]
+        )
+        actions = [action for action in node.children.keys()]
         if temperature == 0:
-            action_pos = numpy.argmax(visit_counts[0])
+            action = actions[numpy.argmax(visit_counts)]
         elif temperature == float("inf"):
-            action_pos = numpy.random.choice(len(visit_counts[1]))
+            action = numpy.random.choice(actions)
         else:
             # See paper appendix Data Generation
-            visit_count_distribution = visit_counts[0] ** (1 / temperature)
+            visit_count_distribution = visit_counts ** (1 / temperature)
             visit_count_distribution = visit_count_distribution / sum(
                 visit_count_distribution
             )
-            action_pos = numpy.random.choice(
-                len(visit_counts[1]), p=visit_count_distribution
-            )
+            action = numpy.random.choice(actions, p=visit_count_distribution)

-        return visit_counts[1][action_pos]
+        return action
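The rewrite keeps the visit counts and the action keys in separate arrays, so `numpy.random.choice` can sample the action directly instead of going through an index. A small standalone illustration of the temperature behaviour, using stub children that only carry a `visit_count` (a stand-in for the real Node objects):

```python
import numpy
from types import SimpleNamespace

# Stub children: action -> object carrying a visit_count (stand-in for the real Node)
children = {0: SimpleNamespace(visit_count=2), 1: SimpleNamespace(visit_count=30), 2: SimpleNamespace(visit_count=8)}

visit_counts = numpy.array([child.visit_count for child in children.values()])
actions = [action for action in children.keys()]

temperature = 1
distribution = visit_counts ** (1 / temperature)
distribution = distribution / sum(distribution)
print(actions, distribution.round(2))            # [0, 1, 2] [0.05 0.75 0.2 ]
print(numpy.random.choice(actions, p=distribution))  # usually 1

# temperature == 0 collapses to the greedy choice
print(actions[numpy.argmax(visit_counts)])       # 1
```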
# Game independent

@@ -130,7 +142,7 @@ class MCTS:
     def __init__(self, config):
         self.config = config

-    def run(self, model, observation, to_play, add_exploration_noise):
+    def run(self, model, observation, legal_actions, to_play, add_exploration_noise):
         """
         At the root of the search tree we use the representation function to obtain a
         hidden state given the current observation.

@@ -147,12 +159,11 @@ class MCTS:
         _, expected_reward, policy_logits, hidden_state = model.initial_inference(
             observation
         )
+        expected_reward = self.support_to_scalar(
+            expected_reward, self.config.support_size
+        )
         root.expand(
-            self.config.action_space,
-            to_play,
-            expected_reward,
-            policy_logits,
-            hidden_state,
+            legal_actions, to_play, expected_reward, policy_logits, hidden_state,
         )
         if add_exploration_noise:
             root.add_exploration_noise(

@@ -183,8 +194,10 @@ class MCTS:
             parent = search_path[-2]
             value, reward, policy_logits, hidden_state = model.recurrent_inference(
                 parent.hidden_state,
-                torch.tensor([[last_action]]).to(parent.hidden_state.device),
+                torch.tensor([last_action]).unsqueeze(1).to(parent.hidden_state.device),
             )
+            value = self.support_to_scalar(value, self.config.support_size)
+            reward = self.support_to_scalar(reward, self.config.support_size)
             node.expand(
                 self.config.action_space,
                 virtual_to_play,

@@ -236,6 +249,29 @@ class MCTS:
             value = node.reward + self.config.discount * value

+    @staticmethod
+    def support_to_scalar(logits, support_size):
+        """
+        Transform a categorical representation to a scalar.
+        See paper appendix Network Architecture.
+        """
+        # Decode to a scalar
+        probs = torch.softmax(logits, dim=1)
+        support = (
+            torch.tensor([x for x in range(-support_size, support_size + 1)])
+            .expand(probs.shape)
+            .to(device=probs.device)
+        )
+        x = torch.sum(support * probs, dim=1, keepdim=True)
+
+        # Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
+        x = torch.sign(x) * (
+            ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
+            ** 2
+            - 1
+        )
+        return x


 class Node:
     def __init__(self, prior):

@@ -294,7 +330,6 @@ class GameHistory:
     def store_search_statistics(self, root, action_space):
         sum_visits = sum(child.visit_count for child in root.children.values())
-        # TODO: action could be of any type, not only integers
         self.child_visits.append(
             [
                 root.children[a].visit_count / sum_visits if a in root.children else 0
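`root.expand` now receives only the legal actions rather than the whole action space, which is the commit's "policy mask": children are created, and the prior softmax is renormalised, only over legal moves. `Node.expand` itself is not part of this diff, so the following is only a minimal sketch of the masking idea, not the repository's implementation:

```python
import torch

def masked_policy(policy_logits, legal_actions):
    """Softmax over the legal actions only; illegal moves get no child and probability 0."""
    # policy_logits: tensor of shape (1, action_space_size)
    logits = policy_logits[0, legal_actions]
    probs = torch.softmax(logits, dim=0)
    return {action: probs[i].item() for i, action in enumerate(legal_actions)}

# e.g. 4 actions, only 0 and 2 are legal: children would be created for these two only
print(masked_policy(torch.tensor([[0.1, 2.0, 0.3, -1.0]]), legal_actions=[0, 2]))
```

At non-root nodes the search still expands over `self.config.action_space`, as the hunk above shows.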
@@ -23,7 +23,8 @@ class Trainer:
             self.config.observation_shape,
             len(self.config.action_space),
             self.config.encoding_size,
-            self.config.hidden_size,
+            self.config.hidden_layers,
+            self.config.support_size
         )
         self.model.set_weights(initial_weights)
         self.model.to(torch.device(config.training_device))

@@ -81,6 +82,9 @@ class Trainer:
         target_reward = torch.tensor(target_reward).float().to(device)
         target_policy = torch.tensor(target_policy).float().to(device)

+        target_value = self.scalar_to_support(target_value, self.config.support_size)
+        target_reward = self.scalar_to_support(target_reward, self.config.support_size)
+
         value, reward, policy_logits, hidden_state = self.model.initial_inference(
             observation_batch
         )

@@ -99,13 +103,13 @@ class Trainer:
                 current_value_loss,
                 current_reward_loss,
                 current_policy_loss,
-            ) = loss_function(
+            ) = self.loss_function(
                 value.squeeze(-1),
                 reward.squeeze(-1),
                 policy_logits,
                 target_value[:, i],
                 target_reward[:, i],
-                target_policy[:, i, :],
+                target_policy[:, i],
             )
             value_loss += current_value_loss
             reward_loss += current_reward_loss

@@ -139,14 +143,35 @@ class Trainer:
         for param_group in self.optimizer.param_groups:
             param_group["lr"] = lr
+    @staticmethod
+    def scalar_to_support(x, support_size):
+        """
+        Transform a scalar to a categorical representation with (2 * support_size + 1) categories.
+        See paper appendix Network Architecture.
+        """
+        # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
+        x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1 + 0.001 * x)
+
+        # Encode on a vector
+        x = torch.clamp(x, -support_size, support_size)
+        floor = x.floor()
+        ceil = x.ceil()
+        prob = x - floor
+        logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
+        logits.scatter_(
+            2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
+        )
+        logits.scatter_(
+            2, (ceil + support_size).long().unsqueeze(-1), prob.unsqueeze(-1)
+        )
+        return logits
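`scalar_to_support` first squashes the target with the scaling from https://arxiv.org/abs/1805.11593 and then spreads it over the two nearest support bins, while `MCTS.support_to_scalar` applies the inverse scaling after taking the expectation under the softmax. A quick standalone check of the two scaling formulas (copied from the methods above and in self-play), showing they invert each other for non-negative inputs:

```python
import torch

x = torch.tensor([[0.0, 1.0, 5.0, 10.0]])

# Scale the target (first transform in scalar_to_support above)
h = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1 + 0.001 * x)

# Invert the scale (last step of MCTS.support_to_scalar shown earlier)
x_back = torch.sign(h) * (
    ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(h) + 1 + 0.001)) - 1) / (2 * 0.001)) ** 2
    - 1
)

print(h)       # approx tensor([[0.0000, 0.4152, 1.4545, 2.3266]])
print(x_back)  # approx the original [0., 1., 5., 10.]
```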
-def loss_function(
+    @staticmethod
+    def loss_function(
         value, reward, policy_logits, target_value, target_reward, target_policy
     ):
-    # TODO: paper promotes cross entropy instead of MSE
-    value_loss = torch.nn.MSELoss()(value, target_value)
-    reward_loss = torch.nn.MSELoss()(reward, target_reward)
-    policy_loss = torch.mean(
-        torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), 1)
-    )
+        # Cross-entropy had better convergence than MSE
+        value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1).mean()
+        reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
+        policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
         return value_loss, reward_loss, policy_loss
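With the value and reward heads now producing `2 * support_size + 1` logits, all three losses share the same soft-target cross-entropy form. A tiny standalone example of that expression on dummy tensors (the shapes match the heads above, the numbers are made up):

```python
import torch

support_size = 10
batch = 4

logits = torch.randn(batch, 2 * support_size + 1)                        # e.g. a value head output
target = torch.softmax(torch.randn(batch, 2 * support_size + 1), dim=1)  # made-up soft target distribution

loss = (-target * torch.nn.LogSoftmax(dim=1)(logits)).sum(1).mean()
print(loss)  # scalar tensor; identical in form to value_loss, reward_loss and policy_loss above
```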