Updated metrics correctly

This commit is contained in:
Vasilis Valatsos 2024-03-07 10:58:05 +02:00
parent 1ab8df01ea
commit 3324e092ef
2 changed files with 11 additions and 11 deletions

10
main.py
View file

@ -185,17 +185,17 @@ def main():
print(f"Models saved to {chkpt_path}") print(f"Models saved to {chkpt_path}")
metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path) metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path, n_episodes)
metrics.plot_score(score_history, parsed_args.n_agents, figure_path) metrics.plot_score(score_history, parsed_args.n_agents, figure_path)
metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path) metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path, n_episodes)
metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path) metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path, n_episodes)
metrics.plot_parameter('entropy', entropy, parsed_args.n_agents, figure_path) metrics.plot_parameter('entropy', entropy, parsed_args.n_agents, figure_path, n_episodes)
metrics.plot_parameter('advantage', advantage, parsed_args.n_agents, figure_path) metrics.plot_parameter('advantage', advantage, parsed_args.n_agents, figure_path, n_episodes)
metrics.plot_avg_time(time_alive, parsed_args.n_agents, figure_path) metrics.plot_avg_time(time_alive, parsed_args.n_agents, figure_path)

View file

@ -3,7 +3,7 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_learning_curve(scores, num_players, figure_path, ep_lenght): def plot_learning_curve(scores, num_players, figure_path, n_episodes):
plt.figure() plt.figure()
plt.title("Running Average - Score") plt.title("Running Average - Score")
@ -13,7 +13,7 @@ def plot_learning_curve(scores, num_players, figure_path, ep_lenght):
for score in scores: for score in scores:
running_avg = np.zeros(len(score)) running_avg = np.zeros(len(score))
for i in range(len(score)): for i in range(len(score)):
running_avg[i] = np.mean(score[max(0, i-int(ep_length/10)):(i+1)]) running_avg[i] = np.mean(score[max(0, i-int(n_episodes/10)):(i+1)])
plt.plot(running_avg) plt.plot(running_avg)
plt.savefig(os.path.join(figure_path, "avg_score.png")) plt.savefig(os.path.join(figure_path, "avg_score.png"))
plt.close() plt.close()
@ -40,7 +40,7 @@ def plot_score(scores, num_players, figure_path):
plt.close() plt.close()
def plot_loss(nn_type, losses, num_players, figure_path, ep_length): def plot_loss(nn_type, losses, num_players, figure_path, n_episodes):
plt.figure() plt.figure()
plt.title(f"Running Average - {nn_type.capitalize()} Loss") plt.title(f"Running Average - {nn_type.capitalize()} Loss")
@ -50,13 +50,13 @@ def plot_loss(nn_type, losses, num_players, figure_path, ep_length):
for loss in losses: for loss in losses:
running_avg = np.zeros(len(loss)) running_avg = np.zeros(len(loss))
for i in range(len(loss)): for i in range(len(loss)):
running_avg[i] = np.mean(loss[max(0, i-int(ep_length/10)):(i+1)]) running_avg[i] = np.mean(loss[max(0, i-int(n_episodes/10)):(i+1)])
plt.plot(running_avg) plt.plot(running_avg)
plt.savefig(os.path.join(figure_path, f"{nn_type}_loss.png")) plt.savefig(os.path.join(figure_path, f"{nn_type}_loss.png"))
plt.close() plt.close()
def plot_parameter(name, parameter, num_players, figure_path, ep_length): def plot_parameter(name, parameter, num_players, figure_path, n_episodes):
plt.figure() plt.figure()
plt.title(f"Running Average - {name.capitalize()}") plt.title(f"Running Average - {name.capitalize()}")
@ -66,7 +66,7 @@ def plot_parameter(name, parameter, num_players, figure_path, ep_length):
for param in parameter: for param in parameter:
running_avg = np.zeros(len(param)) running_avg = np.zeros(len(param))
for i in range(len(param)): for i in range(len(param)):
running_avg[i] = np.mean(param[max(0, i-int(ep_length/10)):(i+1)]) running_avg[i] = np.mean(param[max(0, i-int(n_episodes/10)):(i+1)])
plt.plot(running_avg) plt.plot(running_avg)
plt.savefig(os.path.join(figure_path, f"{name}.png")) plt.savefig(os.path.join(figure_path, f"{name}.png"))
plt.close() plt.close()