Training done

This commit is contained in:
Vasilis Valatsos 2023-11-24 14:23:12 +01:00
parent bb43e56f8c
commit 1f91ec9d5d
4 changed files with 4 additions and 4 deletions

View file

@ -82,7 +82,7 @@ class ActorNetwork(nn.Module):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def load_checkpoint(self, filename = 'actor_torch_ppo'): def load_checkpoint(self, filename = 'actor_torch_ppo'):
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename))) self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device))
class CriticNetwork(nn.Module): class CriticNetwork(nn.Module):
@ -114,4 +114,4 @@ class CriticNetwork(nn.Module):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def load_checkpoint(self, filename = 'critic_torch_ppo'): def load_checkpoint(self, filename = 'critic_torch_ppo'):
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename))) self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device))

View file

@ -14,8 +14,8 @@ random.seed(1)
np.random.seed(1) np.random.seed(1)
T.manual_seed(1) T.manual_seed(1)
n_episodes = 1000 n_episodes = 10000
game_len = 10000 game_len = 20000
figure_file = 'plots/score.png' figure_file = 'plots/score.png'

BIN
tmp/ppo/player_0_actor Executable file

Binary file not shown.

BIN
tmp/ppo/player_0_critic Executable file

Binary file not shown.