Updated the Adam optimizer to reduce plasticity and made the agent more observant, amongst other things

This commit is contained in:
Vasilis Valatsos 2024-02-29 19:07:31 +02:00
parent 7eb7228a8c
commit 78c298e072
5 changed files with 46 additions and 41 deletions

2
.gitignore vendored
View file

@ -164,3 +164,5 @@ cython_debug/
# Random stuff # Random stuff
__pycache__/ __pycache__/
chkpts/
figures/

View file

@ -86,8 +86,8 @@ class Player(pygame.sprite.Sprite):
print("Attempting to load models ...") print("Attempting to load models ...")
try: try:
self.agent.load_models( self.agent.load_models(
actr_chkpt=f"run{load}/A{self.player_id}", actr_chkpt=f"{chkpt_dir}/../run{load}/A{self.player_id}",
crtc_chkpt=f"run{load}/C{self.player_id}" crtc_chkpt=f"{chkpt_dir}/../run{load}/C{self.player_id}"
) )
print("Models loaded ...\n") print("Models loaded ...\n")
@ -160,9 +160,9 @@ class Player(pygame.sprite.Sprite):
if hasattr(self, 'state_features'): if hasattr(self, 'state_features'):
self.old_state_features = self.state_features self.old_state_features = self.state_features
self.reward = self.stats.exp\ self.reward = self.stats.exp
+ self.stats.health/self.stats.stats['health'] - 1\ # + self.stats.health/self.stats.stats['health'] - 1\
- nearest_dist/np.sqrt(np.sum(self.map_edge)) # - nearest_dist/np.sqrt(np.sum(self.map_edge))
self.state_features = [ self.state_features = [
self.animation.rect.center[0]/self.map_edge[0], self.animation.rect.center[0]/self.map_edge[0],
@ -173,27 +173,27 @@ class Player(pygame.sprite.Sprite):
self.stats.energy/self.stats.stats['energy'], self.stats.energy/self.stats.stats['energy'],
] ]
self.state_features.extend([ # self.state_features.extend([
nearest_dist/np.sqrt(np.sum(self.map_edge)), # nearest_dist/np.sqrt(np.sum(self.map_edge)),
nearest_en_dir[0], # nearest_en_dir[0],
nearest_en_dir[1], # nearest_en_dir[1],
nearest_enemy.stats.exp # nearest_enemy.stats.exp
]) # ])
# for distance, direction, enemy in self.distance_direction_from_enemy: for distance, direction, enemy in sorted_distances[:5]:
# self.state_features.extend([ self.state_features.extend([
#
# distance/np.sqrt(np.sum(self.map_edge)), distance/np.sqrt(np.sum(self.map_edge)),
#
# direction[0], direction[0],
#
# direction[1], direction[1],
#
# enemy.stats.health / enemy.stats.health /
# enemy.stats.monster_info['health'], enemy.stats.monster_info['health'],
#
# enemy.stats.exp, enemy.stats.exp,
# ]) ])
if hasattr(self, 'num_features'): if hasattr(self, 'num_features'):
while len(self.state_features) < self.num_features: while len(self.state_features) < self.num_features:

BIN
figures/.DS_Store vendored

Binary file not shown.

View file

@ -52,7 +52,7 @@ class PPOMemory:
class ActorNetwork(nn.Module): class ActorNetwork(nn.Module):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'): def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp'):
super(ActorNetwork, self).__init__() super(ActorNetwork, self).__init__()
self.chkpt_dir = chkpt_dir self.chkpt_dir = chkpt_dir
@ -68,7 +68,7 @@ class ActorNetwork(nn.Module):
nn.Softmax(dim=-1) nn.Softmax(dim=-1)
) )
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu') self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
@ -80,18 +80,19 @@ class ActorNetwork(nn.Module):
return dist return dist
def save_checkpoint(self, filename='actor_torch_ppo'): def save_checkpoint(self, filename):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) T.save(self.state_dict(), os.path.join(filename))
def load_checkpoint(self, filename='actor_torch_ppo'): def load_checkpoint(self, filename):
print(filename)
self.load_state_dict( self.load_state_dict(
T.load(os.path.join(self.chkpt_dir, filename), T.load(os.path.join(filename),
map_location=self.device)) map_location=self.device))
class CriticNetwork(nn.Module): class CriticNetwork(nn.Module):
def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'): def __init__(self, input_dims, alpha, fc1_dims=2048, fc2_dims=2048, chkpt_dir='tmp'):
super(CriticNetwork, self).__init__() super(CriticNetwork, self).__init__()
self.chkpt_dir = chkpt_dir self.chkpt_dir = chkpt_dir
@ -105,16 +106,16 @@ class CriticNetwork(nn.Module):
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), # nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), # nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), # nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), # nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc2_dims, 1) nn.Linear(fc2_dims, 1)
) )
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu') self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
@ -124,10 +125,11 @@ class CriticNetwork(nn.Module):
value = self.critic(state) value = self.critic(state)
return value return value
def save_checkpoint(self, filename='critic_torch_ppo'): def save_checkpoint(self, filename):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) T.save(self.state_dict(), os.path.join(filename))
def load_checkpoint(self, filename='critic_torch_ppo'): def load_checkpoint(self, filename):
print(filename)
self.load_state_dict( self.load_state_dict(
T.load(os.path.join(self.chkpt_dir, filename), T.load(os.path.join(filename),
map_location=self.device)) map_location=self.device))

View file

@ -25,6 +25,7 @@ def plot_avg_time(time_steps, num_players, fig_path):
for player in time_steps: for player in time_steps:
plt.plot(player) plt.plot(player)
plt.savefig(os.path.join(fig_path, 'avg_time.png')) plt.savefig(os.path.join(fig_path, 'avg_time.png'))
plt.close()
def plot_score(scores, num_players, figure_path): def plot_score(scores, num_players, figure_path):