diff --git a/args.py b/args.py index 64a9ad4..3e26d88 100644 --- a/args.py +++ b/args.py @@ -13,7 +13,7 @@ def parse_args(): parser.add_argument('--no_seed', default=False, action="store_true", - help="Set to True to run without a seed.") + help="Set to run without a seed.") parser.add_argument('--seed', type=int, @@ -83,6 +83,11 @@ def parse_args(): action="store_true", help="Set flag to disable learning. Useful for viewing trained agents interact in the environment.") + parser.add_argument('--load', + type=int, + default=None, + help="Run id to load within chkpt_path") + parser.add_argument('--show_pg', default=False, action="store_true", diff --git a/entities/player.py b/entities/player.py index 9902fea..0ee6379 100644 --- a/entities/player.py +++ b/entities/player.py @@ -61,7 +61,7 @@ class Player(pygame.sprite.Sprite): gae_lambda, chkpt_dir, entropy_coef, - no_load=False): + load=None): self.max_num_enemies = len(self.distance_direction_from_enemy) self.get_current_state() @@ -82,12 +82,12 @@ class Player(pygame.sprite.Sprite): print( f"\nAgent initialized on player {self.player_id} using {self.agent.actor.device}.") - if not no_load: + if load: print("Attempting to load models ...") try: self.agent.load_models( - actr_chkpt=f"A{self.player_id}", - crtc_chkpt=f"C{self.player_id}" + actr_chkpt=f"run{load}/A{self.player_id}", + crtc_chkpt=f"run{load}/C{self.player_id}" ) print("Models loaded ...\n") diff --git a/main.py b/main.py index 2de0006..31fa95d 100644 --- a/main.py +++ b/main.py @@ -80,7 +80,7 @@ def main(): gae_lambda=parsed_args.gae_lambda, entropy_coef=parsed_args.entropy_coeff, chkpt_dir=chkpt_path, - no_load=True + load=parsed_args.load ) # Episodes start @@ -173,6 +173,8 @@ def main(): metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path) + metrics.plot_score(score_history, parsed_args.n_agents, figure_path) + metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path) metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path) diff --git a/utils/metrics.py b/utils/metrics.py index bd8e363..a7e339a 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -3,34 +3,6 @@ import numpy as np import matplotlib.pyplot as plt -def generate(parsed_args): - - # Setup parameter monitoring - score_history = np.zeros( - shape=(parsed_args.n_agents, parsed_args.n_episodes)) - - best_score = np.zeros(parsed_args.n_agents) - - actor_loss = np.zeros(shape=(parsed_args.n_agents, - parsed_args.n_episodes)) - - critic_loss = np.zeros(shape=(parsed_args.n_agents, - parsed_args.n_episodes)) - - total_loss = np.zeros(shape=(parsed_args.n_agents, - parsed_args.n_episodes)) - - entropy = np.zeros(shape=(parsed_args.n_agents, - parsed_args.n_episodes)) - - advantage = np.zeros(shape=(parsed_args.n_agents, - parsed_args.n_episodes)) - - return score_history, best_score, actor_loss, - critic_loss, total_loss, entropy, - advantage - - def plot_learning_curve(scores, num_players, figure_path): plt.figure()