diff --git a/agents/saved_models/A0 b/agents/saved_models/A0 index c61c8f1..d1dbc28 100644 Binary files a/agents/saved_models/A0 and b/agents/saved_models/A0 differ diff --git a/agents/saved_models/C0 b/agents/saved_models/C0 index eb5f400..42f68bb 100644 Binary files a/agents/saved_models/C0 and b/agents/saved_models/C0 differ diff --git a/entities/player.py b/entities/player.py index ca1d59e..60f8df8 100644 --- a/entities/player.py +++ b/entities/player.py @@ -62,6 +62,7 @@ class Player(pygame.sprite.Sprite): chkpt_dir, no_load=False): + self.max_num_enemies = len(self.distance_direction_from_enemy) self.get_current_state() self.num_features = len(self.state_features) @@ -158,9 +159,17 @@ class Player(pygame.sprite.Sprite): self.reward_features = [ self.stats.exp, + 2*np.exp(-nearest_dist**2), - np.exp(-nearest_enemy.stats.health), - -np.exp(-self.stats.health**2) + + 1/(np.exp((nearest_enemy.stats.health - + nearest_enemy.stats.monster_info['health'])/nearest_enemy.stats.monster_info['health'])) - 1, + + 1/(np.exp((len(self.distance_direction_from_enemy) - + self.max_num_enemies)/self.max_num_enemies)) - 1, + + 1 - 1/(np.exp((self.stats.health - + self.stats.stats['health'])/self.stats.stats['health'])) if not self.is_dead() > 0 else -1 ] @@ -233,7 +242,7 @@ class Player(pygame.sprite.Sprite): self.stats.energy_recovery() else: - self.stats.exp = max(0, self.stats.exp - .01) + self.stats.exp = max(-1, self.stats.exp - .1) # Refresh player based on input and animate self.get_status() diff --git a/pneuma.py b/pneuma.py index 9b6af0c..2a30b64 100644 --- a/pneuma.py +++ b/pneuma.py @@ -1,4 +1,4 @@ -import os +import random import argparse import torch as T import numpy as np @@ -55,7 +55,7 @@ if __name__ == "__main__": parser.add_argument('--horizon', type=int, - default=200, + default=2048, help="The number of steps per update") parser.add_argument('--show_pg', @@ -100,6 +100,7 @@ if __name__ == "__main__": args = parser.parse_args() + random.seed(args.seed) np.random.seed(args.seed) T.manual_seed(args.seed) @@ -219,12 +220,6 @@ if __name__ == "__main__": print(f"Models saved to {chkpt_path}") - else: - print(f"\nScore this round for player {player.player_id}:\ - {player.stats.exp}") - - - # End of training session print("End of episodes.\ \nExiting game...") @@ -264,4 +259,5 @@ if __name__ == "__main__": for total in total_loss: plt.plot(total) plt.savefig(f"{figure_folder}/total_loss.png") + game.quit()