Fix MCTS
werner-duvaud committed Mar 11, 2020
1 parent 283e353 commit 0918977
Showing 6 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -6,7 +6,7 @@

# MuZero General

A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py).
A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) and the associated [pseudocode](https://arxiv.org/src/1911.08265v2/anc/pseudocode.py).
It is designed to be easily adaptable for every games or reinforcement learning environments (like [gym](https://github.com/openai/gym)). You only need to edit the [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the parameters and the game class. Please refer to the [documentation](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).

MuZero is a model based reinforcement learning algorithm, successor of AlphaZero. It learns to master games without knowing the rules. It only knows actions and then learn to play and master the game. It is at least more efficient than similar algorithms like [AlphaZero](https://arxiv.org/abs/1712.01815), [SimPLe](https://arxiv.org/abs/1903.00374) and [World Models](https://arxiv.org/abs/1803.10122). See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works)
20 changes: 10 additions & 10 deletions games/lunarlander.py
@@ -21,11 +21,11 @@ def __init__(self):


### Self-Play
self.num_actors = 4 # Number of simultaneous threads self-playing to feed the replay buffer
self.num_actors = 10 # Number of simultaneous threads self-playing to feed the replay buffer
self.max_moves = 2000 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of futur moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.temperature_threshold = 2000 # Number of moves before dropping temperature to 0 (ie playing according to the max)
self.temperature_threshold = 600 # Number of moves before dropping temperature to 0 (ie playing according to the max)
self.self_play_delay = 0 # Number of seconds to wait after each played game to adjust the self play / training ratio to avoid over/underfitting

# Root prior exploration noise
@@ -51,30 +51,30 @@ def __init__(self):
self.resnet_fc_policy_layers = [] # Define the hidden layers in the policy head of the prediction network

# Fully Connected Network
self.encoding_size = 16
self.fc_reward_layers = [16] # Define the hidden layers in the reward network
self.fc_value_layers = [] # Define the hidden layers in the value network
self.fc_policy_layers = [] # Define the hidden layers in the policy network
self.encoding_size = 10
self.fc_reward_layers = [8] # Define the hidden layers in the reward network
self.fc_value_layers = [8] # Define the hidden layers in the value network
self.fc_policy_layers = [8] # Define the hidden layers in the policy network
self.fc_representation_layers = [] # Define the hidden layers in the representation network
self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
self.fc_dynamics_layers = [8] # Define the hidden layers in the dynamics network


### Training
self.results_path = os.path.join(os.path.dirname(__file__), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")) # Path to store the model weights and TensorBoard logs
self.training_steps = 50000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.num_unroll_steps = 10 # Number of game moves to keep for every batch element
self.num_unroll_steps = 200 # Number of game moves to keep for every batch element
self.checkpoint_interval = 10 # Number of training steps before using the model for sef-playing
self.window_size = 1000 # Number of self-play games to keep in the replay buffer
self.td_steps = 2000 # Number of steps in the futur to take into account for calculating the target value
self.td_steps = 600 # Number of steps in the futur to take into account for calculating the target value
self.training_delay = 0 # Number of seconds to wait after each training to adjust the self play / training ratio to avoid over/underfitting
self.training_device = "cuda" if torch.cuda.is_available() else "cpu" # Train on GPU if available

self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9

# Exponential learning rate schedule
self.lr_init = 0.05 # Initial learning rate
self.lr_init = 0.5 # Initial learning rate
self.lr_decay_rate = 1 # Set it to 1 to use a constant learning rate
self.lr_decay_steps = 1000

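The three learning rate settings above usually feed an exponential schedule of the form lr_init * lr_decay_rate ** (training_step / lr_decay_steps). A small illustration under that assumption (the exact formula lives in the trainer, not in this diff):

```python
# Illustration only: assumed exponential learning rate schedule driven by the config above.
lr_init, lr_decay_rate, lr_decay_steps = 0.5, 1, 1000

def learning_rate(training_step):
    return lr_init * lr_decay_rate ** (training_step / lr_decay_steps)

print(learning_rate(0))      # 0.5
print(learning_rate(10000))  # still 0.5: with lr_decay_rate = 1 the rate stays constant
```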
15 changes: 7 additions & 8 deletions models.py
@@ -87,9 +87,9 @@ def __init__(
)

def prediction(self, encoded_state):
policy_logit = self.prediction_policy_network(encoded_state)
policy_logits = self.prediction_policy_network(encoded_state)
value = self.prediction_value_network(encoded_state)
return policy_logit, value
return policy_logits, value

def representation(self, observation):
encoded_state = self.representation_network(
@@ -131,20 +131,20 @@ def dynamics(self, encoded_state, action):

def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logit, value = self.prediction(encoded_state)
policy_logits, value = self.prediction(encoded_state)
return (
value,
torch.zeros(len(observation), self.full_support_size).to(
observation.device
),
policy_logit,
policy_logits,
encoded_state,
)

def recurrent_inference(self, encoded_state, action):
next_encoded_state, reward = self.dynamics(encoded_state, action)
policy_logit, value = self.prediction(next_encoded_state)
return value, reward, policy_logit, next_encoded_state
policy_logits, value = self.prediction(next_encoded_state)
return value, reward, policy_logits, next_encoded_state

def get_weights(self):
return {key: value.cpu() for key, value in self.state_dict().items()}
@@ -460,8 +460,7 @@ def set_weights(self, weights):
class FullyConnectedNetwork(torch.nn.Module):
def __init__(self, input_size, layer_sizes, output_size, activation=None):
super(FullyConnectedNetwork, self).__init__()
sizes_list = layer_sizes.copy()
sizes_list.insert(0, input_size)
sizes_list = [input_size] + layer_sizes
layers = []
if 1 < len(sizes_list):
for i in range(len(sizes_list) - 1):
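For context, a minimal sketch of how the new `sizes_list = [input_size] + layer_sizes` line can drive the rest of the MLP construction; the ReLU hidden activations and optional output activation are assumptions, since the diff cuts off after the loop header:

```python
import torch

class FullyConnectedNetwork(torch.nn.Module):
    """Minimal sketch, not the repository's exact code."""

    def __init__(self, input_size, layer_sizes, output_size, activation=None):
        super().__init__()
        # Equivalent to copying layer_sizes and inserting input_size at the front,
        # just written in one expression.
        sizes_list = [input_size] + layer_sizes
        layers = []
        for i in range(len(sizes_list) - 1):
            layers += [torch.nn.Linear(sizes_list[i], sizes_list[i + 1]), torch.nn.ReLU()]
        layers.append(torch.nn.Linear(sizes_list[-1], output_size))
        if activation is not None:
            layers.append(activation)
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```

With `fc_value_layers = [8]` from the lunarlander config, this would build Linear(encoding_size, 8) -> ReLU -> Linear(8, output_size).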
4 changes: 2 additions & 2 deletions muzero.py
@@ -45,7 +45,7 @@ def __init__(self, game_name):
)
raise err

# Fix random generator seed for reproductibility
# Fix random generator seed
numpy.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)

@@ -258,6 +258,6 @@ def load_model(self, path=None):
## Successive training, create a new config file for each experiment
# experiments = ["cartpole", "tictactoe"]
# for experiment in experiments:
# print("\nStrating experiment {}".format(experiment))
# print("\nStarting experiment {}".format(experiment))
# muzero = MuZero(experiment)
# muzero.train()
18 changes: 11 additions & 7 deletions self_play.py
@@ -211,7 +211,7 @@ def select_action(node, temperature):
return action


# Game independant
# Game independent
class MCTS:
"""
Core Monte Carlo Tree Search algorithm.
@@ -311,7 +311,7 @@ def ucb_score(self, parent, child, min_max_stats):
pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

prior_score = pb_c * child.prior
value_score = min_max_stats.normalize(child.value())

if child.visit_count > 0:
value_score = child.reward + self.config.discount * min_max_stats.normalize(child.value())
else:
value_score = 0

return prior_score + value_score
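Putting the hunk above together, the corrected selection rule reads roughly as below. Only the last few lines appear in this diff; the log term and the pb_c_base / pb_c_init / discount defaults follow the MuZero pseudocode and the game config, so treat them as assumptions:

```python
import math

def ucb_score(parent, child, min_max_stats, pb_c_base=19652, pb_c_init=1.25, discount=0.997):
    # Exploration term, scaled by the child's prior probability.
    pb_c = math.log((parent.visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior

    # Exploitation term: only trusted once the child has been visited, and now
    # including the transition reward before the discounted, normalized value.
    if child.visit_count > 0:
        value_score = child.reward + discount * min_max_stats.normalize(child.value())
    else:
        value_score = 0

    return prior_score + value_score
```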

@@ -320,7 +324,7 @@ def backpropagate(self, search_path, value, to_play, min_max_stats):
At the end of a simulation, we propagate the evaluation all the way up the tree
to the root.
"""
for node in search_path:
for node in reversed(search_path):
node.value_sum += value if node.to_play == to_play else -value
node.visit_count += 1
min_max_stats.update(node.value())
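Iterating over reversed(search_path) makes the statistics update run leaf-to-root. A minimal sketch of the full pass; the last line, which bootstraps the value handed to the parent through the node reward and the discount, follows the MuZero pseudocode and is not shown in this diff:

```python
def backpropagate(search_path, value, to_play, discount, min_max_stats):
    # Walk from the evaluated leaf back up to the root.
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        # Assumed bootstrap step: the parent sees this node's reward plus the
        # discounted value accumulated so far.
        value = node.reward + discount * value
```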
@@ -334,14 +338,14 @@ def support_to_scalar(logits, support_size):
See paper appendix Network Architecture
"""
# Decode to a scalar
probs = torch.softmax(logits, dim=1)
probabilities = torch.softmax(logits, dim=1)
support = (
torch.tensor([x for x in range(-support_size, support_size + 1)])
.expand(probs.shape)
.expand(probabilities.shape)
.float()
.to(device=probs.device)
.to(device=probabilities.device)
)
x = torch.sum(support * probs, dim=1, keepdim=True)
x = torch.sum(support * probabilities, dim=1, keepdim=True)

# Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
x = torch.sign(x) * (
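The hunk above is cut off mid-expression; the commented link points at the value scaling h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x whose inverse is being applied. A sketch of the whole function under that reading, with eps = 0.001 assumed (the value used in the MuZero pseudocode):

```python
import torch

def support_to_scalar(logits, support_size, eps=0.001):
    # Expected value of the categorical distribution over {-support_size, ..., support_size}.
    probabilities = torch.softmax(logits, dim=1)
    support = (
        torch.arange(-support_size, support_size + 1)
        .expand(probabilities.shape)
        .float()
        .to(probabilities.device)
    )
    x = torch.sum(support * probabilities, dim=1, keepdim=True)

    # Closed-form inverse of h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    # (https://arxiv.org/abs/1805.11593).
    x = torch.sign(x) * (
        ((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
    )
    return x
```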
2 changes: 1 addition & 1 deletion trainer.py
@@ -78,7 +78,7 @@ def update_weights(self, batch):
target_value = torch.tensor(target_value).float().to(device)
target_reward = torch.tensor(target_reward).float().to(device)
target_policy = torch.tensor(target_policy).float().to(device)
# observation_batch: batch, channels, heigth, width
# observation_batch: batch, channels, height, width
# action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
# target_value: batch, num_unroll_steps+1
# target_reward: batch, num_unroll_steps+1
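As a concrete reading of the shape comments above, dummy tensors with those shapes would look like this (the numeric values and the observation shape are placeholders, purely for illustration):

```python
import torch

# Illustration only: dummy tensors matching the shape comments in update_weights.
batch_size, num_unroll_steps = 128, 10          # example values
channels, height, width = 1, 1, 8               # placeholder observation shape

observation_batch = torch.zeros(batch_size, channels, height, width)
action_batch = torch.zeros(batch_size, num_unroll_steps + 1, 1, dtype=torch.long)
target_value = torch.zeros(batch_size, num_unroll_steps + 1)
target_reward = torch.zeros(batch_size, num_unroll_steps + 1)
```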
