Weird bug showed up; added diagnostics
This commit is contained in:
parent
d0098af801
commit
a9868e6c1a
7 changed files with 13 additions and 8 deletions
Binary file not shown.
|
@@ -22,17 +22,18 @@ class Agent:
|
||||||
def remember(self, state, action, probs, vals, reward, done):
|
def remember(self, state, action, probs, vals, reward, done):
|
||||||
self.memory.store_memory(state, action, probs, vals, reward, done)
|
self.memory.store_memory(state, action, probs, vals, reward, done)
|
||||||
|
|
||||||
def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
|
||||||
self.actor.save_checkpoint(actr_chkpt)
|
self.actor.save_checkpoint(actr_chkpt)
|
||||||
self.critic.save_checkpoint(crtc_chkpt)
|
self.critic.save_checkpoint(crtc_chkpt)
|
||||||
|
|
||||||
def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
|
||||||
self.actor.load_checkpoint(actr_chkpt)
|
self.actor.load_checkpoint(actr_chkpt)
|
||||||
self.critic.load_checkpoint(crtc_chkpt)
|
self.critic.load_checkpoint(crtc_chkpt)
|
||||||
|
|
||||||
def choose_action(self, observation):
|
def choose_action(self, observation):
|
||||||
|
print(f"observation: {observation}")
|
||||||
state = T.tensor(observation, dtype=T.float).to(self.actor.device)
|
state = T.tensor(observation, dtype=T.float).to(self.actor.device)
|
||||||
|
print(f"state: {state}")
|
||||||
dist = self.actor(state)
|
dist = self.actor(state)
|
||||||
value = self.critic(state)
|
value = self.critic(state)
|
||||||
action = dist.sample()
|
action = dist.sample()
|
||||||
|
|
|
@@ -158,16 +158,20 @@ class Player(pygame.sprite.Sprite):
|
||||||
self.num_features = len(self.state_features)
|
self.num_features = len(self.state_features)
|
||||||
|
|
||||||
def setup_agent(self):
|
def setup_agent(self):
|
||||||
print(f"Initializing Agent {self.player_id} ...")
|
print(f"Initializing agent on player {self.player_id} ...")
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
input_dims=len(self.state_features),
|
input_dims=len(self.state_features),
|
||||||
n_actions=len(self._input.possible_actions),
|
n_actions=len(self._input.possible_actions),
|
||||||
batch_size=5,
|
batch_size=5,
|
||||||
n_epochs=4)
|
n_epochs=4)
|
||||||
print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
|
print(
|
||||||
|
f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
|
self.agent.load_models(
|
||||||
|
actr_chkpt=f"player_{self.player_id}_actor", crtc_chkpt=f"player_{self.player_id}_critic")
|
||||||
print("Models loaded ...\n")
|
print("Models loaded ...\n")
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("FileNotFound for agent. Skipping loading...\n")
|
print("FileNotFound for agent. Skipping loading...\n")
|
||||||
|
|
||||||
|
|
BIN
plots/score_sp.png
Normal file
BIN
plots/score_sp.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
|
@@ -14,7 +14,7 @@ random.seed(1)
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
T.manual_seed(1)
|
T.manual_seed(1)
|
||||||
|
|
||||||
n_episodes = 2000
|
n_episodes = 1000
|
||||||
game_len = 5000
|
game_len = 5000
|
||||||
|
|
||||||
figure_file = 'plots/score_sp.png'
|
figure_file = 'plots/score_sp.png'
|
||||||
|
@@ -37,7 +37,7 @@ for i in tqdm(range(n_episodes)):
|
||||||
player.stats.exp = score_history[player.player_id][i-1]
|
player.stats.exp = score_history[player.player_id][i-1]
|
||||||
player.agent = agent
|
player.agent = agent
|
||||||
|
|
||||||
for j in range(game_len):
|
for j in tqdm(range(game_len)):
|
||||||
if not game.level.done:
|
if not game.level.done:
|
||||||
|
|
||||||
game.run()
|
game.run()
|
||||||
|
|
BIN
tmp/ppo/player_actor
Normal file
BIN
tmp/ppo/player_actor
Normal file
Binary file not shown.
BIN
tmp/ppo/player_critic
Normal file
BIN
tmp/ppo/player_critic
Normal file
Binary file not shown.
Loading…
Reference in a new issue