diff --git a/plots/score_sp.png b/plots/score_sp.png index c42dccc..98c3bda 100644 Binary files a/plots/score_sp.png and b/plots/score_sp.png differ diff --git a/single-agent.py b/single-agent.py index 9e447fd..49441f8 100644 --- a/single-agent.py +++ b/single-agent.py @@ -1,4 +1,3 @@ -import random import torch as T import numpy as np import matplotlib.pyplot as plt @@ -24,8 +23,8 @@ game = Game(n_players) agent = game.level.player_sprites[0].agent score_history = np.zeros(shape=(game.max_num_players, n_episodes)) -best_score = np.zeros(game.max_num_players) -avg_score = np.zeros(game.max_num_players) +best_score = 0 +avg_score = np.zeros(n_episodes) for i in tqdm(range(n_episodes)): # TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste) @@ -55,22 +54,21 @@ for i in tqdm(range(n_episodes)): for player in game.level.player_sprites: exp_points = player.stats.exp score_history[player.player_id][i] = exp_points - avg_score[player.player_id] = np.mean( - score_history[player.player_id]) + avg_score[i] = np.mean(score_history) - if np.mean(avg_score) > np.mean(best_score): + if avg_score[i] >= best_score: print( - f"\nNew Best score: {np.mean(avg_score)}\ - \nOld Best score: {np.mean(best_score)}" + f"\nNew Best score: {avg_score[i]}\ + \nOld Best score: {best_score}" ) - best_score = avg_score + best_score = avg_score[i] print("Saving models for agent...") agent.save_models( actr_chkpt="player_actor", crtc_chkpt="player_critic") print("Models saved ...\n") else: print( - f"Average score of round: {np.mean(avg_score)}\ + f"Average score of round: {avg_score[i]}\ \nBest score: {np.mean(best_score)}" ) @@ -80,7 +78,7 @@ agent.save_models( actr_chkpt="player_actor", crtc_chkpt="player_critic") print("Models saved ...\n") -plt.plot(score_history) +plt.plot(avg_score) plt.savefig(figure_file) game.quit() diff --git a/tmp/ppo/player_actor b/tmp/ppo/player_actor index c046887..9e4e887 100644 Binary files a/tmp/ppo/player_actor and b/tmp/ppo/player_actor differ diff --git a/tmp/ppo/player_critic b/tmp/ppo/player_critic index cbfc2f7..3643a59 100644 Binary files a/tmp/ppo/player_critic and b/tmp/ppo/player_critic differ