Add resume training and improve training exit
Co-authored-by: ahainaut <59559236+ahainaut@users.noreply.github.com>
werner-duvaud and ahainaut committed Aug 20, 2020
1 parent a25b9c0 commit de80a8b
Showing 17 changed files with 331 additions and 303 deletions.
diagnose_model.py (2 additions, 3 deletions)
@@ -193,13 +193,12 @@ class DiagnoseModel:
         config: configuration class instance related to the weights.
     """

-    def __init__(self, weights, config):
+    def __init__(self, checkpoint, config):
         self.config = config

         # Initialize the network
         self.model = models.MuZeroNetwork(self.config)
-        self.model.set_weights(weights)
-        self.model.to(torch.device("cpu"))
+        self.model.set_weights(checkpoint["weights"])
         self.model.eval()

     def get_virtual_trajectory_from_obs(
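DiagnoseModel now receives the full checkpoint dictionary instead of bare weights, and the explicit move to CPU is dropped (a freshly constructed model already lives on the CPU). A minimal usage sketch, assuming a checkpoint serialized with torch.save that contains at least a "weights" entry; the game module and file path below are illustrative, not part of this commit:

import torch

from diagnose_model import DiagnoseModel
from games.cartpole import MuZeroConfig  # illustrative choice of game config

config = MuZeroConfig()
# Hypothetical path; the game configs save the checkpoint as model.checkpoint inside results_path.
checkpoint = torch.load("results/cartpole/model.checkpoint")
diagnose = DiagnoseModel(checkpoint, config)  # previously DiagnoseModel(weights, config)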
games/atari.py (1 addition, 1 deletion)
@@ -79,7 +79,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = int(1000e3)  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 1024  # Number of parts of games to train on at each training step
         self.checkpoint_interval = int(1e3)  # Number of training steps before using the model for self-playing
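The same rename, save_weights to save_model, is applied to every game configuration below: the option now writes a full checkpoint to results_path as model.checkpoint rather than weights only. A rough sketch of how such a checkpoint could be written and read back to resume training; this is an assumption about the mechanism, not the repository's actual trainer code, and every key except "weights" is illustrative:

import os

import torch


def save_checkpoint(model, optimizer, training_step, results_path):
    # Sketch: persist everything needed to resume training, not only the network weights.
    checkpoint = {
        "weights": model.get_weights(),             # CPU copies (see dict_to_cpu in models.py)
        "optimizer_state": optimizer.state_dict(),  # illustrative key
        "training_step": training_step,             # illustrative key
    }
    torch.save(checkpoint, os.path.join(results_path, "model.checkpoint"))


def load_checkpoint(model, results_path):
    # Sketch: read the checkpoint back and restore the weights before resuming.
    checkpoint = torch.load(os.path.join(results_path, "model.checkpoint"))
    model.set_weights(checkpoint["weights"])
    return checkpoint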
games/breakout.py (1 addition, 1 deletion)
@@ -79,7 +79,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = int(1000e3)  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 16  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 500  # Number of training steps before using the model for self-playing
games/cartpole.py (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 10000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
games/connect4.py (1 addition, 2 deletions)
@@ -1,7 +1,6 @@
 import datetime
 import os

-import gym
 import numpy
 import torch

@@ -74,7 +73,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 100000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 64  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
games/gomoku.py (1 addition, 2 deletions)
@@ -2,7 +2,6 @@
 import math
 import os

-import gym
 import numpy
 import torch

@@ -75,7 +74,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 10000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 512  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 50  # Number of training steps before using the model for self-playing
games/gridworld.py (1 addition, 1 deletion)
@@ -79,7 +79,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 30000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 128  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
games/lunarlander.py (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 200000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 64  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
games/simple_grid.py (1 addition, 2 deletions)
@@ -1,7 +1,6 @@
 import datetime
 import os

-import gym
 import numpy
 import torch

@@ -74,7 +73,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 30000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 32  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing
games/tictactoe.py (1 addition, 2 deletions)
@@ -1,7 +1,6 @@
 import datetime
 import os

-import gym
 import numpy
 import torch

@@ -74,7 +73,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 1000000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 64  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 100  # Number of training steps before using the model for self-playing
games/twentyone.py (3 additions, 5 deletions)
@@ -7,9 +7,7 @@

 import datetime
 import os
-from random import randint

-import gym
 import numpy
 import torch

@@ -82,7 +80,7 @@ def __init__(self):

         ### Training
         self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
-        self.save_weights = True  # Save the weights in results_path as model.weights
+        self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 15000  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 64  # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10  # Number of training steps before using the model for self-playing

@@ -228,7 +226,7 @@ def action_to_string(self, action_number):

 class TwentyOne:
     def __init__(self, seed):
-        numpy.random.seed(seed)
+        self.random = numpy.random.RandomState(seed)

         self.player_hand = self.deal_card_value()
         self.dealer_hand = self.deal_card_value()

@@ -287,7 +285,7 @@ def get_reward(self, done):
             return -1

     def deal_card_value(self):
-        card = randint(1, 13)
+        card = self.random.randint(1, 13)
         if card >= 10:
             value = 10
         else:
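Replacing the module-level random.randint and the global numpy.random.seed with a per-instance numpy.random.RandomState keeps each game's randomness self-contained and reproducible: nothing else in the process can disturb the stream by reseeding the global generator. (One factual aside: numpy's randint treats the upper bound as exclusive, whereas random.randint includes it.) A small standalone sketch of the idea, independent of the TwentyOne class:

import numpy

# Sketch: two generators seeded identically reproduce the same card sequence,
# and neither touches the global numpy.random state shared by other code.
rng_a = numpy.random.RandomState(0)
rng_b = numpy.random.RandomState(0)

cards_a = [rng_a.randint(1, 13) for _ in range(5)]  # draws in [1, 13), i.e. 1..12
cards_b = [rng_b.randint(1, 13) for _ in range(5)]
assert cards_a == cards_b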
models.py (13 additions, 1 deletion)
@@ -41,6 +41,18 @@ def __new__(cls, config):
             )


+def dict_to_cpu(dictionary):
+    cpu_dict = {}
+    for key, value in dictionary.items():
+        if isinstance(value, torch.Tensor):
+            cpu_dict[key] = value.cpu()
+        elif isinstance(value, dict):
+            cpu_dict[key] = dict_to_cpu(value)
+        else:
+            cpu_dict[key] = value
+    return cpu_dict
+
+
 class AbstractNetwork(ABC, torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -55,7 +67,7 @@ def recurrent_inference(self, encoded_state, action):
         pass

     def get_weights(self):
-        return {key: value.cpu() for key, value in self.state_dict().items()}
+        return dict_to_cpu(self.state_dict())

     def set_weights(self, weights):
         self.load_state_dict(weights)
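The old dictionary comprehension assumed every value in the state dict was a tensor; dict_to_cpu recurses into nested dictionaries and leaves non-tensor values untouched, so the same helper also works on richer structures such as an optimizer state_dict stored in a checkpoint. A short, self-contained sketch (the Linear layer and SGD optimizer are only there to produce nested tensor state):

import torch

from models import dict_to_cpu

# A tiny model and optimizer, used only to generate nested state containing tensors.
layer = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
layer(torch.ones(1, 4)).sum().backward()
optimizer.step()  # populates momentum buffers: tensors nested inside per-parameter dicts

cpu_state = dict_to_cpu(optimizer.state_dict())
# Tensors anywhere in the nested dict are now CPU copies; lists and scalars pass through unchanged.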