Added score_metric and loading options

This commit is contained in:
Vasilis Valatsos 2024-02-13 11:48:24 +01:00
parent aaaf7a2829
commit 7eb7228a8c
4 changed files with 13 additions and 34 deletions

View file

@ -13,7 +13,7 @@ def parse_args():
parser.add_argument('--no_seed', parser.add_argument('--no_seed',
default=False, default=False,
action="store_true", action="store_true",
help="Set to True to run without a seed.") help="Set to run without a seed.")
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
@ -83,6 +83,11 @@ def parse_args():
action="store_true", action="store_true",
help="Set flag to disable learning. Useful for viewing trained agents interact in the environment.") help="Set flag to disable learning. Useful for viewing trained agents interact in the environment.")
parser.add_argument('--load',
type=int,
default=None,
help="Run id to load within chkpt_path")
parser.add_argument('--show_pg', parser.add_argument('--show_pg',
default=False, default=False,
action="store_true", action="store_true",

View file

@ -61,7 +61,7 @@ class Player(pygame.sprite.Sprite):
gae_lambda, gae_lambda,
chkpt_dir, chkpt_dir,
entropy_coef, entropy_coef,
no_load=False): load=None):
self.max_num_enemies = len(self.distance_direction_from_enemy) self.max_num_enemies = len(self.distance_direction_from_enemy)
self.get_current_state() self.get_current_state()
@ -82,12 +82,12 @@ class Player(pygame.sprite.Sprite):
print( print(
f"\nAgent initialized on player {self.player_id} using {self.agent.actor.device}.") f"\nAgent initialized on player {self.player_id} using {self.agent.actor.device}.")
if not no_load: if load:
print("Attempting to load models ...") print("Attempting to load models ...")
try: try:
self.agent.load_models( self.agent.load_models(
actr_chkpt=f"A{self.player_id}", actr_chkpt=f"run{load}/A{self.player_id}",
crtc_chkpt=f"C{self.player_id}" crtc_chkpt=f"run{load}/C{self.player_id}"
) )
print("Models loaded ...\n") print("Models loaded ...\n")

View file

@ -80,7 +80,7 @@ def main():
gae_lambda=parsed_args.gae_lambda, gae_lambda=parsed_args.gae_lambda,
entropy_coef=parsed_args.entropy_coeff, entropy_coef=parsed_args.entropy_coeff,
chkpt_dir=chkpt_path, chkpt_dir=chkpt_path,
no_load=True load=parsed_args.load
) )
# Episodes start # Episodes start
@ -173,6 +173,8 @@ def main():
metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path) metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path)
metrics.plot_score(score_history, parsed_args.n_agents, figure_path)
metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path) metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path)
metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path) metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path)

View file

@ -3,34 +3,6 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def generate(parsed_args):
    """Allocate the zero-initialised arrays used to monitor a training run.

    Args:
        parsed_args: Object exposing integer ``n_agents`` and ``n_episodes``
            attributes (e.g. an ``argparse.Namespace``).

    Returns:
        A 7-tuple ``(score_history, best_score, actor_loss, critic_loss,
        total_loss, entropy, advantage)``. ``best_score`` has shape
        ``(n_agents,)``; every other array has shape
        ``(n_agents, n_episodes)``. Each array is a distinct object.
    """
    # Setup parameter monitoring
    per_episode_shape = (parsed_args.n_agents, parsed_args.n_episodes)

    score_history = np.zeros(shape=per_episode_shape)
    best_score = np.zeros(parsed_args.n_agents)
    actor_loss = np.zeros(shape=per_episode_shape)
    critic_loss = np.zeros(shape=per_episode_shape)
    total_loss = np.zeros(shape=per_episode_shape)
    entropy = np.zeros(shape=per_episode_shape)
    advantage = np.zeros(shape=per_episode_shape)

    # Parenthesised so the tuple can span several lines: without explicit or
    # implicit line joining, only the first physical line belongs to the
    # return statement.
    return (
        score_history,
        best_score,
        actor_loss,
        critic_loss,
        total_loss,
        entropy,
        advantage,
    )
def plot_learning_curve(scores, num_players, figure_path): def plot_learning_curve(scores, num_players, figure_path):
plt.figure() plt.figure()