diff --git a/main.py b/main.py index 36ee202..3af7301 100644 --- a/main.py +++ b/main.py @@ -185,17 +185,17 @@ def main(): print(f"Models saved to {chkpt_path}") - metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path) + metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path, n_episodes) metrics.plot_score(score_history, parsed_args.n_agents, figure_path) - metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path) + metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path, n_episodes) - metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path) + metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path, n_episodes) - metrics.plot_parameter('entropy', entropy, parsed_args.n_agents, figure_path) + metrics.plot_parameter('entropy', entropy, parsed_args.n_agents, figure_path, n_episodes) - metrics.plot_parameter('advantage', advantage, parsed_args.n_agents, figure_path) + metrics.plot_parameter('advantage', advantage, parsed_args.n_agents, figure_path, n_episodes) metrics.plot_avg_time(time_alive, parsed_args.n_agents, figure_path) diff --git a/utils/metrics.py b/utils/metrics.py index d1f206e..11bfc8e 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt -def plot_learning_curve(scores, num_players, figure_path, ep_lenght): +def plot_learning_curve(scores, num_players, figure_path, n_episodes): plt.figure() plt.title("Running Average - Score") @@ -13,7 +13,7 @@ def plot_learning_curve(scores, num_players, figure_path, ep_lenght): for score in scores: running_avg = np.zeros(len(score)) for i in range(len(score)): - running_avg[i] = np.mean(score[max(0, i-int(ep_length/10)):(i+1)]) + running_avg[i] = np.mean(score[max(0, i-int(n_episodes/10)):(i+1)]) plt.plot(running_avg) plt.savefig(os.path.join(figure_path, "avg_score.png")) plt.close() @@ -40,7 +40,7 @@ def plot_score(scores, num_players, figure_path): plt.close() -def plot_loss(nn_type, losses, num_players, figure_path, ep_length): +def plot_loss(nn_type, losses, num_players, figure_path, n_episodes): plt.figure() plt.title(f"Running Average - {nn_type.capitalize()} Loss") @@ -50,13 +50,13 @@ def plot_loss(nn_type, losses, num_players, figure_path, ep_length): for loss in losses: running_avg = np.zeros(len(loss)) for i in range(len(loss)): - running_avg[i] = np.mean(loss[max(0, i-int(ep_length/10)):(i+1)]) + running_avg[i] = np.mean(loss[max(0, i-int(n_episodes/10)):(i+1)]) plt.plot(running_avg) plt.savefig(os.path.join(figure_path, f"{nn_type}_loss.png")) plt.close() -def plot_parameter(name, parameter, num_players, figure_path, ep_length): +def plot_parameter(name, parameter, num_players, figure_path, n_episodes): plt.figure() plt.title(f"Running Average - {name.capitalize()}") @@ -66,7 +66,7 @@ def plot_parameter(name, parameter, num_players, figure_path, ep_length): for param in parameter: running_avg = np.zeros(len(param)) for i in range(len(param)): - running_avg[i] = np.mean(param[max(0, i-int(ep_length/10)):(i+1)]) + running_avg[i] = np.mean(param[max(0, i-int(n_episodes/10)):(i+1)]) plt.plot(running_avg) plt.savefig(os.path.join(figure_path, f"{name}.png")) plt.close()