Updated Adam to reduce plasticity and made the agent more observant, among other things

This commit is contained in:
Vasilis Valatsos 2024-02-29 19:07:31 +02:00
parent 7eb7228a8c
commit 78c298e072
5 changed files with 46 additions and 41 deletions

2
.gitignore vendored
View file

@@ -164,3 +164,5 @@ cython_debug/
# Random stuff
__pycache__/
chkpts/
figures/

View file

@@ -86,8 +86,8 @@ class Player(pygame.sprite.Sprite):
print("Attempting to load models ...")
try:
self.agent.load_models(
actr_chkpt=f"run{load}/A{self.player_id}",
crtc_chkpt=f"run{load}/C{self.player_id}"
actr_chkpt=f"{chkpt_dir}/../run{load}/A{self.player_id}",
crtc_chkpt=f"{chkpt_dir}/../run{load}/C{self.player_id}"
)
print("Models loaded ...\n")
@@ -160,9 +160,9 @@ class Player(pygame.sprite.Sprite):
if hasattr(self, 'state_features'):
self.old_state_features = self.state_features
self.reward = self.stats.exp\
+ self.stats.health/self.stats.stats['health'] - 1\
- nearest_dist/np.sqrt(np.sum(self.map_edge))
self.reward = self.stats.exp
# + self.stats.health/self.stats.stats['health'] - 1\
# - nearest_dist/np.sqrt(np.sum(self.map_edge))
self.state_features = [
self.animation.rect.center[0]/self.map_edge[0],
@@ -173,28 +173,28 @@ class Player(pygame.sprite.Sprite):
self.stats.energy/self.stats.stats['energy'],
]
self.state_features.extend([
nearest_dist/np.sqrt(np.sum(self.map_edge)),
nearest_en_dir[0],
nearest_en_dir[1],
nearest_enemy.stats.exp
])
# for distance, direction, enemy in self.distance_direction_from_enemy:
# self.state_features.extend([
#
# distance/np.sqrt(np.sum(self.map_edge)),
#
# direction[0],
#
# direction[1],
#
# enemy.stats.health /
# enemy.stats.monster_info['health'],
#
# enemy.stats.exp,
# nearest_dist/np.sqrt(np.sum(self.map_edge)),
# nearest_en_dir[0],
# nearest_en_dir[1],
# nearest_enemy.stats.exp
# ])
for distance, direction, enemy in sorted_distances[:5]:
self.state_features.extend([
distance/np.sqrt(np.sum(self.map_edge)),
direction[0],
direction[1],
enemy.stats.health /
enemy.stats.monster_info['health'],
enemy.stats.exp,
])
if hasattr(self, 'num_features'):
while len(self.state_features) < self.num_features:
self.state_features.append(0)

BIN
figures/.DS_Store vendored

Binary file not shown.

View file

@@ -52,7 +52,7 @@ class PPOMemory:
class ActorNetwork(nn.Module):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp'):
super(ActorNetwork, self).__init__()
self.chkpt_dir = chkpt_dir
@@ -68,7 +68,7 @@ class ActorNetwork(nn.Module):
nn.Softmax(dim=-1)
)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
@@ -80,18 +80,19 @@ class ActorNetwork(nn.Module):
return dist
def save_checkpoint(self, filename='actor_torch_ppo'):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def save_checkpoint(self, filename):
T.save(self.state_dict(), os.path.join(filename))
def load_checkpoint(self, filename='actor_torch_ppo'):
def load_checkpoint(self, filename):
print(filename)
self.load_state_dict(
T.load(os.path.join(self.chkpt_dir, filename),
T.load(os.path.join(filename),
map_location=self.device))
class CriticNetwork(nn.Module):
def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'):
def __init__(self, input_dims, alpha, fc1_dims=2048, fc2_dims=2048, chkpt_dir='tmp'):
super(CriticNetwork, self).__init__()
self.chkpt_dir = chkpt_dir
@@ -105,16 +106,16 @@ class CriticNetwork(nn.Module):
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
# nn.Linear(fc1_dims, fc2_dims),
# nn.LeakyReLU(),
# nn.Linear(fc1_dims, fc2_dims),
# nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc2_dims, 1)
)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
@@ -124,10 +125,11 @@ class CriticNetwork(nn.Module):
value = self.critic(state)
return value
def save_checkpoint(self, filename='critic_torch_ppo'):
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def save_checkpoint(self, filename):
T.save(self.state_dict(), os.path.join(filename))
def load_checkpoint(self, filename='critic_torch_ppo'):
def load_checkpoint(self, filename):
print(filename)
self.load_state_dict(
T.load(os.path.join(self.chkpt_dir, filename),
T.load(os.path.join(filename),
map_location=self.device))

View file

@@ -25,6 +25,7 @@ def plot_avg_time(time_steps, num_players, fig_path):
for player in time_steps:
plt.plot(player)
plt.savefig(os.path.join(fig_path, 'avg_time.png'))
plt.close()
def plot_score(scores, num_players, figure_path):