Weird bug showed, added diagnostics
This commit is contained in:
parent
d0098af801
commit
a9868e6c1a
7 changed files with 13 additions and 8 deletions
Binary file not shown.
|
@ -22,17 +22,18 @@ class Agent:
|
|||
def remember(self, state, action, probs, vals, reward, done):
|
||||
self.memory.store_memory(state, action, probs, vals, reward, done)
|
||||
|
||||
def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
||||
def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
|
||||
self.actor.save_checkpoint(actr_chkpt)
|
||||
self.critic.save_checkpoint(crtc_chkpt)
|
||||
|
||||
def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
||||
def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
|
||||
self.actor.load_checkpoint(actr_chkpt)
|
||||
self.critic.load_checkpoint(crtc_chkpt)
|
||||
|
||||
def choose_action(self, observation):
|
||||
print(f"observation: {observation}")
|
||||
state = T.tensor(observation, dtype=T.float).to(self.actor.device)
|
||||
|
||||
print(f"state: {state}")
|
||||
dist = self.actor(state)
|
||||
value = self.critic(state)
|
||||
action = dist.sample()
|
||||
|
|
|
@ -158,16 +158,20 @@ class Player(pygame.sprite.Sprite):
|
|||
self.num_features = len(self.state_features)
|
||||
|
||||
def setup_agent(self):
|
||||
print(f"Initializing Agent {self.player_id} ...")
|
||||
print(f"Initializing agent on player {self.player_id} ...")
|
||||
self.agent = Agent(
|
||||
input_dims=len(self.state_features),
|
||||
n_actions=len(self._input.possible_actions),
|
||||
batch_size=5,
|
||||
n_epochs=4)
|
||||
print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
|
||||
print(
|
||||
f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
|
||||
|
||||
try:
|
||||
self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
|
||||
self.agent.load_models(
|
||||
actr_chkpt=f"player_{self.player_id}_actor", crtc_chkpt=f"player_{self.player_id}_critic")
|
||||
print("Models loaded ...\n")
|
||||
|
||||
except FileNotFoundError:
|
||||
print("FileNotFound for agent. Skipping loading...\n")
|
||||
|
||||
|
|
BIN
plots/score_sp.png
Normal file
BIN
plots/score_sp.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
|
@ -14,7 +14,7 @@ random.seed(1)
|
|||
np.random.seed(1)
|
||||
T.manual_seed(1)
|
||||
|
||||
n_episodes = 2000
|
||||
n_episodes = 1000
|
||||
game_len = 5000
|
||||
|
||||
figure_file = 'plots/score_sp.png'
|
||||
|
@ -37,7 +37,7 @@ for i in tqdm(range(n_episodes)):
|
|||
player.stats.exp = score_history[player.player_id][i-1]
|
||||
player.agent = agent
|
||||
|
||||
for j in range(game_len):
|
||||
for j in tqdm(range(game_len)):
|
||||
if not game.level.done:
|
||||
|
||||
game.run()
|
||||
|
|
BIN
tmp/ppo/player_actor
Normal file
BIN
tmp/ppo/player_actor
Normal file
Binary file not shown.
BIN
tmp/ppo/player_critic
Normal file
BIN
tmp/ppo/player_critic
Normal file
Binary file not shown.
Loading…
Reference in a new issue