diff --git a/agents/ppo/brain.py b/agents/ppo/brain.py index 3ab73b3..3c97910 100644 --- a/agents/ppo/brain.py +++ b/agents/ppo/brain.py @@ -82,7 +82,7 @@ class ActorNetwork(nn.Module): T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) def load_checkpoint(self, filename = 'actor_torch_ppo'): - self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename))) + self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device)) class CriticNetwork(nn.Module): @@ -114,4 +114,4 @@ class CriticNetwork(nn.Module): T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) def load_checkpoint(self, filename = 'critic_torch_ppo'): - self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename))) + self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device)) diff --git a/main.py b/main.py index 94312f2..d0059f8 100644 --- a/main.py +++ b/main.py @@ -14,8 +14,8 @@ random.seed(1) np.random.seed(1) T.manual_seed(1) -n_episodes = 1000 -game_len = 10000 +n_episodes = 10000 +game_len = 20000 figure_file = 'plots/score.png' diff --git a/tmp/ppo/player_0_actor b/tmp/ppo/player_0_actor new file mode 100755 index 0000000..3d47cf7 Binary files /dev/null and b/tmp/ppo/player_0_actor differ diff --git a/tmp/ppo/player_0_critic b/tmp/ppo/player_0_critic new file mode 100755 index 0000000..f11f82b Binary files /dev/null and b/tmp/ppo/player_0_critic differ