More fixes to saving and reporting score
This commit is contained in:
parent
4e8dcb766f
commit
6a84d0b3f4
4 changed files with 9 additions and 11 deletions
Binary file not shown.
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 19 KiB |
|
@ -1,4 +1,3 @@
|
||||||
import random
|
|
||||||
import torch as T
|
import torch as T
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -24,8 +23,8 @@ game = Game(n_players)
|
||||||
agent = game.level.player_sprites[0].agent
|
agent = game.level.player_sprites[0].agent
|
||||||
|
|
||||||
score_history = np.zeros(shape=(game.max_num_players, n_episodes))
|
score_history = np.zeros(shape=(game.max_num_players, n_episodes))
|
||||||
best_score = np.zeros(game.max_num_players)
|
best_score = 0
|
||||||
avg_score = np.zeros(game.max_num_players)
|
avg_score = np.zeros(n_episodes)
|
||||||
|
|
||||||
for i in tqdm(range(n_episodes)):
|
for i in tqdm(range(n_episodes)):
|
||||||
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
||||||
|
@ -55,22 +54,21 @@ for i in tqdm(range(n_episodes)):
|
||||||
for player in game.level.player_sprites:
|
for player in game.level.player_sprites:
|
||||||
exp_points = player.stats.exp
|
exp_points = player.stats.exp
|
||||||
score_history[player.player_id][i] = exp_points
|
score_history[player.player_id][i] = exp_points
|
||||||
avg_score[player.player_id] = np.mean(
|
avg_score[i] = np.mean(score_history)
|
||||||
score_history[player.player_id])
|
|
||||||
|
|
||||||
if np.mean(avg_score) > np.mean(best_score):
|
if avg_score[i] >= best_score:
|
||||||
print(
|
print(
|
||||||
f"\nNew Best score: {np.mean(avg_score)}\
|
f"\nNew Best score: {avg_score[i]}\
|
||||||
\nOld Best score: {np.mean(best_score)}"
|
\nOld Best score: {best_score}"
|
||||||
)
|
)
|
||||||
best_score = avg_score
|
best_score = avg_score[i]
|
||||||
print("Saving models for agent...")
|
print("Saving models for agent...")
|
||||||
agent.save_models(
|
agent.save_models(
|
||||||
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
||||||
print("Models saved ...\n")
|
print("Models saved ...\n")
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f"Average score of round: {np.mean(avg_score)}\
|
f"Average score of round: {avg_score[i]}\
|
||||||
\nBest score: {np.mean(best_score)}"
|
\nBest score: {np.mean(best_score)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -80,7 +78,7 @@ agent.save_models(
|
||||||
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
actr_chkpt="player_actor", crtc_chkpt="player_critic")
|
||||||
print("Models saved ...\n")
|
print("Models saved ...\n")
|
||||||
|
|
||||||
plt.plot(score_history)
|
plt.plot(avg_score)
|
||||||
plt.savefig(figure_file)
|
plt.savefig(figure_file)
|
||||||
game.quit()
|
game.quit()
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue