diff --git a/.single-agent.py.kate-swp b/.single-agent.py.kate-swp
deleted file mode 100644
index 9939bf4..0000000
Binary files a/.single-agent.py.kate-swp and /dev/null differ
diff --git a/agents/ppo/agent.py b/agents/ppo/agent.py
index 06899ea..6d60518 100644
--- a/agents/ppo/agent.py
+++ b/agents/ppo/agent.py
@@ -22,17 +22,18 @@ class Agent:
     def remember(self, state, action, probs, vals, reward, done):
         self.memory.store_memory(state, action, probs, vals, reward, done)
 
-    def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
+    def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
         self.actor.save_checkpoint(actr_chkpt)
         self.critic.save_checkpoint(crtc_chkpt)
 
-    def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
+    def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
         self.actor.load_checkpoint(actr_chkpt)
         self.critic.load_checkpoint(crtc_chkpt)
 
     def choose_action(self, observation):
+        print(f"observation: {observation}")
         state = T.tensor(observation, dtype=T.float).to(self.actor.device)
-
+        print(f"state: {state}")
         dist = self.actor(state)
         value = self.critic(state)
         action = dist.sample()
diff --git a/entities/player.py b/entities/player.py
index 59c50ee..a9d601c 100644
--- a/entities/player.py
+++ b/entities/player.py
@@ -158,16 +158,20 @@ class Player(pygame.sprite.Sprite):
         self.num_features = len(self.state_features)
 
     def setup_agent(self):
-        print(f"Initializing Agent {self.player_id} ...")
+        print(f"Initializing agent on player {self.player_id} ...")
         self.agent = Agent(
             input_dims=len(self.state_features),
             n_actions=len(self._input.possible_actions),
             batch_size=5,
             n_epochs=4)
-        print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
+        print(
+            f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
+
         try:
-            self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
+            self.agent.load_models(
+                actr_chkpt=f"player_{self.player_id}_actor", crtc_chkpt=f"player_{self.player_id}_critic")
             print("Models loaded ...\n")
+
         except FileNotFoundError:
             print("FileNotFound for agent. Skipping loading...\n")
 
diff --git a/plots/score_sp.png b/plots/score_sp.png
new file mode 100644
index 0000000..c42dccc
Binary files /dev/null and b/plots/score_sp.png differ
diff --git a/single-agent.py b/single-agent.py
index 0a74c84..6b7f1ef 100644
--- a/single-agent.py
+++ b/single-agent.py
@@ -14,7 +14,7 @@ random.seed(1)
 np.random.seed(1)
 T.manual_seed(1)
 
-n_episodes = 2000
+n_episodes = 1000
 game_len = 5000
 figure_file = 'plots/score_sp.png'
 
@@ -37,7 +37,7 @@ for i in tqdm(range(n_episodes)):
         player.stats.exp = score_history[player.player_id][i-1]
         player.agent = agent
 
-    for j in range(game_len):
+    for j in tqdm(range(game_len)):
 
         if not game.level.done:
             game.run()
diff --git a/tmp/ppo/player_actor b/tmp/ppo/player_actor
new file mode 100644
index 0000000..ba41026
Binary files /dev/null and b/tmp/ppo/player_actor differ
diff --git a/tmp/ppo/player_critic b/tmp/ppo/player_critic
new file mode 100644
index 0000000..c6db4f8
Binary files /dev/null and b/tmp/ppo/player_critic differ
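
Note on the per-player checkpoint names introduced in entities/player.py: Player.setup_agent() now loads "player_{player_id}_actor" / "player_{player_id}_critic", so a matching save call has to use the same names or the FileNotFoundError branch will always be taken on the next run. The following is a minimal sketch only, not part of this diff; the `players` collection and the end-of-training placement are assumptions, while Agent.save_models(actr_chkpt=..., crtc_chkpt=...) and player.player_id come from the code above.

# Hypothetical end-of-training hook (assumed loop and `players` iterable),
# mirroring the checkpoint naming that load_models() expects in setup_agent().
for player in players:
    player.agent.save_models(
        actr_chkpt=f"player_{player.player_id}_actor",
        crtc_chkpt=f"player_{player.player_id}_critic")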