From 78c298e072ababbcad32fb26c40385c1521ef9ad Mon Sep 17 00:00:00 2001 From: Vasilis Valatsos Date: Thu, 29 Feb 2024 19:07:31 +0200 Subject: [PATCH] Updated adam to reduce plasticity and made the agent more observant, amongst other things --- .gitignore | 2 ++ entities/player.py | 50 ++++++++++++++++++++++----------------------- figures/.DS_Store | Bin 8196 -> 6148 bytes ml/ppo/brain.py | 34 +++++++++++++++--------------- utils/metrics.py | 1 + 5 files changed, 46 insertions(+), 41 deletions(-) diff --git a/.gitignore b/.gitignore index c2b0f72..de7d615 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,5 @@ cython_debug/ # Random stuff __pycache__/ +chkpts/ +figures/ \ No newline at end of file diff --git a/entities/player.py b/entities/player.py index 0ee6379..0a38e3c 100644 --- a/entities/player.py +++ b/entities/player.py @@ -86,8 +86,8 @@ class Player(pygame.sprite.Sprite): print("Attempting to load models ...") try: self.agent.load_models( - actr_chkpt=f"run{load}/A{self.player_id}", - crtc_chkpt=f"run{load}/C{self.player_id}" + actr_chkpt=f"{chkpt_dir}/../run{load}/A{self.player_id}", + crtc_chkpt=f"{chkpt_dir}/../run{load}/C{self.player_id}" ) print("Models loaded ...\n") @@ -160,9 +160,9 @@ class Player(pygame.sprite.Sprite): if hasattr(self, 'state_features'): self.old_state_features = self.state_features - self.reward = self.stats.exp\ - + self.stats.health/self.stats.stats['health'] - 1\ - - nearest_dist/np.sqrt(np.sum(self.map_edge)) + self.reward = self.stats.exp + # + self.stats.health/self.stats.stats['health'] - 1\ + # - nearest_dist/np.sqrt(np.sum(self.map_edge)) self.state_features = [ self.animation.rect.center[0]/self.map_edge[0], @@ -173,27 +173,27 @@ class Player(pygame.sprite.Sprite): self.stats.energy/self.stats.stats['energy'], ] - self.state_features.extend([ - nearest_dist/np.sqrt(np.sum(self.map_edge)), - nearest_en_dir[0], - nearest_en_dir[1], - nearest_enemy.stats.exp - ]) + # self.state_features.extend([ + # nearest_dist/np.sqrt(np.sum(self.map_edge)), + # nearest_en_dir[0], + # nearest_en_dir[1], + # nearest_enemy.stats.exp + # ]) - # for distance, direction, enemy in self.distance_direction_from_enemy: - # self.state_features.extend([ - # - # distance/np.sqrt(np.sum(self.map_edge)), - # - # direction[0], - # - # direction[1], - # - # enemy.stats.health / - # enemy.stats.monster_info['health'], - # - # enemy.stats.exp, - # ]) + for distance, direction, enemy in sorted_distances[:5]: + self.state_features.extend([ + + distance/np.sqrt(np.sum(self.map_edge)), + + direction[0], + + direction[1], + + enemy.stats.health / + enemy.stats.monster_info['health'], + + enemy.stats.exp, + ]) if hasattr(self, 'num_features'): while len(self.state_features) < self.num_features: diff --git a/figures/.DS_Store b/figures/.DS_Store index 5cab4fb9670122911c93bb6be415a7639323f794..2424cb8aa497a6716f56b3f2c0af6a68f0eb0a9d 100644 GIT binary patch delta 283 zcmZp1XfcprU|?W$DortDU=RQ@Ie-{MGjUEV6q~50D8veq2Z^N=C+8&P=jSkPEL_Ft z3=(ExC}JpO$YU^sOCbpwVG}e?DlaZb%E?axD%>`iNvKd&th(CJNJqiguvSN*+T7Ga zN5Rs>M0| p%s?Ix2yg=lSCGFp7Jg@*%rD~z@;Jz6EFhW*;&ZUr=6Iet%mCZXIZprp literal 8196 zcmeHM&2G~`5S~o~iB(dfrl`FjS#XU?6Nx~@#e{I+iV++Dg*Y~;h2xE4ha93vxgmaT z+!1fUktg60;9WSu%42cs}o}GY$Z-?#yig>;XU(H^TZRHW^0e zX**>lb)A*&m0ElHgEDZZt2MCU-&{DXPD5&UDRu4>3nJO`uq~m#$h9g=ER}v*1 zn4|+!-^^5n!o;^D&lJOfSrT=t70?QtR)A{vEf~NMVyxSf_4__dAcWDqfDfI3J2(zg zi(N+C{~=5`+s@?np#>hYGkw7Al#fQCk9LliPbn`KF@`Dl7>AF&U+*C@uc7oc#ct%V z3nO7~Q(kt?@RActX_WF)!+Q~ciyG1$d5#f{V1|7UqSXEP{^v|*9!=Dp{On8uKZ*|y zzOm9q`P}(R#i(pnp4vxzZpVXo7Ig>H$F%o=^T0Xo51fa-H}BVX?(i@k_@Ory1mE>= z$dgBY=<<1&&q6m6Tw9(njH=PE?=BXtcBg6HyuNIji^Ci3rr9~{ESE<0($#CX?)P5= zi@g> z@Bi0H5*@2nVC@uO<(}Q^V%W#OyV43!t!?7=3^y_ww~{C**obl*QIz9|Cw~}X-^8KB au@tT(;t{le@gjf<8U6eh^6O?z75E8#B~}9f diff --git a/ml/ppo/brain.py b/ml/ppo/brain.py index 04e178e..9b95852 100644 --- a/ml/ppo/brain.py +++ b/ml/ppo/brain.py @@ -52,7 +52,7 @@ class PPOMemory: class ActorNetwork(nn.Module): - def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'): + def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp'): super(ActorNetwork, self).__init__() self.chkpt_dir = chkpt_dir @@ -68,7 +68,7 @@ class ActorNetwork(nn.Module): nn.Softmax(dim=-1) ) - self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) + self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5) self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu') @@ -80,18 +80,19 @@ class ActorNetwork(nn.Module): return dist - def save_checkpoint(self, filename='actor_torch_ppo'): - T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) + def save_checkpoint(self, filename): + T.save(self.state_dict(), os.path.join(filename)) - def load_checkpoint(self, filename='actor_torch_ppo'): + def load_checkpoint(self, filename): + print(filename) self.load_state_dict( - T.load(os.path.join(self.chkpt_dir, filename), + T.load(os.path.join(filename), map_location=self.device)) class CriticNetwork(nn.Module): - def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'): + def __init__(self, input_dims, alpha, fc1_dims=2048, fc2_dims=2048, chkpt_dir='tmp'): super(CriticNetwork, self).__init__() self.chkpt_dir = chkpt_dir @@ -105,16 +106,16 @@ class CriticNetwork(nn.Module): nn.LeakyReLU(), nn.Linear(fc1_dims, fc2_dims), nn.LeakyReLU(), - nn.Linear(fc1_dims, fc2_dims), - nn.LeakyReLU(), - nn.Linear(fc1_dims, fc2_dims), - nn.LeakyReLU(), + # nn.Linear(fc1_dims, fc2_dims), + # nn.LeakyReLU(), + # nn.Linear(fc1_dims, fc2_dims), + # nn.LeakyReLU(), nn.Linear(fc1_dims, fc2_dims), nn.LeakyReLU(), nn.Linear(fc2_dims, 1) ) - self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) + self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5) self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu') @@ -124,10 +125,11 @@ class CriticNetwork(nn.Module): value = self.critic(state) return value - def save_checkpoint(self, filename='critic_torch_ppo'): - T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename)) + def save_checkpoint(self, filename): + T.save(self.state_dict(), os.path.join(filename)) - def load_checkpoint(self, filename='critic_torch_ppo'): + def load_checkpoint(self, filename): + print(filename) self.load_state_dict( - T.load(os.path.join(self.chkpt_dir, filename), + T.load(os.path.join(filename), map_location=self.device)) diff --git a/utils/metrics.py b/utils/metrics.py index a7e339a..990dfd9 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -25,6 +25,7 @@ def plot_avg_time(time_steps, num_players, fig_path): for player in time_steps: plt.plot(player) plt.savefig(os.path.join(fig_path, 'avg_time.png')) + plt.close() def plot_score(scores, num_players, figure_path):