Weird bug showed, added diagnostics

This commit is contained in:
Vasilis Valatsos 2023-11-24 21:15:50 +01:00
parent d0098af801
commit a9868e6c1a
7 changed files with 13 additions and 8 deletions

Binary file not shown.

View file

@ -22,17 +22,18 @@ class Agent:
def remember(self, state, action, probs, vals, reward, done): def remember(self, state, action, probs, vals, reward, done):
self.memory.store_memory(state, action, probs, vals, reward, done) self.memory.store_memory(state, action, probs, vals, reward, done)
def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'): def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
self.actor.save_checkpoint(actr_chkpt) self.actor.save_checkpoint(actr_chkpt)
self.critic.save_checkpoint(crtc_chkpt) self.critic.save_checkpoint(crtc_chkpt)
def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'): def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
self.actor.load_checkpoint(actr_chkpt) self.actor.load_checkpoint(actr_chkpt)
self.critic.load_checkpoint(crtc_chkpt) self.critic.load_checkpoint(crtc_chkpt)
def choose_action(self, observation): def choose_action(self, observation):
print(f"observation: {observation}")
state = T.tensor(observation, dtype=T.float).to(self.actor.device) state = T.tensor(observation, dtype=T.float).to(self.actor.device)
print(f"state: {state}")
dist = self.actor(state) dist = self.actor(state)
value = self.critic(state) value = self.critic(state)
action = dist.sample() action = dist.sample()

View file

@ -158,16 +158,20 @@ class Player(pygame.sprite.Sprite):
self.num_features = len(self.state_features) self.num_features = len(self.state_features)
def setup_agent(self): def setup_agent(self):
print(f"Initializing Agent {self.player_id} ...") print(f"Initializing agent on player {self.player_id} ...")
self.agent = Agent( self.agent = Agent(
input_dims=len(self.state_features), input_dims=len(self.state_features),
n_actions=len(self._input.possible_actions), n_actions=len(self._input.possible_actions),
batch_size=5, batch_size=5,
n_epochs=4) n_epochs=4)
print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...") print(
f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
try: try:
self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic") self.agent.load_models(
actr_chkpt=f"player_{self.player_id}_actor", crtc_chkpt=f"player_{self.player_id}_critic")
print("Models loaded ...\n") print("Models loaded ...\n")
except FileNotFoundError: except FileNotFoundError:
print("FileNotFound for agent. Skipping loading...\n") print("FileNotFound for agent. Skipping loading...\n")

BIN
plots/score_sp.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View file

@ -14,7 +14,7 @@ random.seed(1)
np.random.seed(1) np.random.seed(1)
T.manual_seed(1) T.manual_seed(1)
n_episodes = 2000 n_episodes = 1000
game_len = 5000 game_len = 5000
figure_file = 'plots/score_sp.png' figure_file = 'plots/score_sp.png'
@ -37,7 +37,7 @@ for i in tqdm(range(n_episodes)):
player.stats.exp = score_history[player.player_id][i-1] player.stats.exp = score_history[player.player_id][i-1]
player.agent = agent player.agent = agent
for j in range(game_len): for j in tqdm(range(game_len)):
if not game.level.done: if not game.level.done:
game.run() game.run()

BIN
tmp/ppo/player_actor Normal file

Binary file not shown.

BIN
tmp/ppo/player_critic Normal file

Binary file not shown.