More fixes to saving and reporting score
This commit is contained in:
parent
4e8dcb766f
commit
6a84d0b3f4
4 changed files with 9 additions and 11 deletions
Binary file not shown.
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 19 KiB |
|
@ -1,4 +1,3 @@
|
|||
import random
|
||||
import torch as T
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -24,8 +23,8 @@ game = Game(n_players)
|
|||
agent = game.level.player_sprites[0].agent
|
||||
|
||||
score_history = np.zeros(shape=(game.max_num_players, n_episodes))
|
||||
best_score = np.zeros(game.max_num_players)
|
||||
avg_score = np.zeros(game.max_num_players)
|
||||
best_score = 0
|
||||
avg_score = np.zeros(n_episodes)
|
||||
|
||||
for i in tqdm(range(n_episodes)):
|
||||
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
||||
|
@ -55,22 +54,21 @@ for i in tqdm(range(n_episodes)):
|
|||
for player in game.level.player_sprites:
|
||||
exp_points = player.stats.exp
|
||||
score_history[player.player_id][i] = exp_points
|
||||
avg_score[player.player_id] = np.mean(
|
||||
score_history[player.player_id])
|
||||
avg_score[i] = np.mean(score_history)
|
||||
|
||||
if np.mean(avg_score) > np.mean(best_score):
|
||||
if avg_score[i] >= best_score:
|
||||
print(
|
||||
f"\nNew Best score: {np.mean(avg_score)}\
|
||||
\nOld Best score: {np.mean(best_score)}"
|
||||
f"\nNew Best score: {avg_score[i]}\
|
||||
\nOld Best score: {best_score}"
|
||||
)
|
||||
best_score = avg_score
|
||||
best_score = avg_score[i]
|
||||
print("Saving models for agent...")
|
||||
agent.save_models(
|
||||
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
||||
print("Models saved ...\n")
|
||||
else:
|
||||
print(
|
||||
f"Average score of round: {np.mean(avg_score)}\
|
||||
f"Average score of round: {avg_score[i]}\
|
||||
\nBest score: {np.mean(best_score)}"
|
||||
)
|
||||
|
||||
|
@ -80,7 +78,7 @@ agent.save_models(
|
|||
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
||||
print("Models saved ...\n")
|
||||
|
||||
plt.plot(score_history)
|
||||
plt.plot(avg_score)
|
||||
plt.savefig(figure_file)
|
||||
game.quit()
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue