2023-11-23 15:37:02 +00:00
|
|
|
import random
|
|
|
|
import torch as T
|
2023-11-23 11:44:23 +00:00
|
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
2023-11-19 03:27:47 +00:00
|
|
|
from game import Game
|
2023-11-17 02:19:03 +00:00
|
|
|
from tqdm import tqdm
|
2023-09-27 18:03:37 +00:00
|
|
|
|
2023-11-19 03:27:47 +00:00
|
|
|
from os import environ
|
|
|
|
# Silence pygame's import-time support prompt.
# NOTE(review): pygame reads this variable when it is first imported. The
# `from game import Game` import earlier in this file presumably imports pygame
# already, in which case this assignment comes too late to have any effect —
# consider moving it above that import. TODO confirm.
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'

# Seed every RNG the run may draw from so training is reproducible. `random`
# is imported at the top of the file but was never seeded. If the game logic
# uses the stdlib `random` module (not visible here), that randomness would
# otherwise ignore the fixed seed.
random.seed(1)
np.random.seed(1)
T.manual_seed(1)

# Training configuration.
n_episodes = 300   # number of episodes (fresh levels) to train for
game_len = 10000   # maximum number of game.run() steps per episode
n_players = 8      # number of players requested for each level

# Where the per-player score plot is written at the end of training.
figure_file = 'plots/score_sp.png'
|
2023-11-17 02:19:03 +00:00
|
|
|
|
2023-11-25 11:47:07 +00:00
|
|
|
# Build the game; its constructor provides the first level and its players.
game = Game(n_players)

# Every player is handed this same agent object in the training loop, so keep
# a reference to the first player's agent for re-attaching after level resets.
agent = game.level.player_sprites[0].agent

# Per-player bookkeeping indexed by player_id: score_history is
# (player, episode); avg_score / best_score hold one value per player.
score_history = np.zeros((game.max_num_players, n_episodes))
avg_score = np.zeros(game.max_num_players)
best_score = np.zeros_like(avg_score)
|
|
|
|
|
|
|
|
for i in tqdm(range(n_episodes)):
    if i != 0:
        # Start a fresh level for this episode.
        # TODO: Make game.level.reset_map() so we don't __init__ everything all
        # the time (such a waste).
        game.level.__init__(n_players, reset=True)

        # Re-initialising the level replaces the players, so restore each
        # player's exp from the previous episode and attach the shared agent.
        # TODO: Make game.level.reset_map() so we don't pull out and load the
        # agent every time (there is definitely a better way).
        for player in game.level.player_sprites:
            player.stats.exp = score_history[player.player_id, i - 1]
            player.agent = agent

    # Step the level until it reports done, for at most game_len steps.
    # Breaking out early avoids spinning through no-op iterations once the
    # level is finished.
    for _ in tqdm(range(game_len), leave=False):
        if game.level.done:
            break
        game.run()
        game.calc_score()

        # Remove players that died during this step.
        for player in game.level.player_sprites:
            if player.is_dead():
                player.kill()

    # Record this episode's score for every player still in the level.
    # NOTE(review): players removed by kill() above presumably drop out of
    # player_sprites, so their entry for this episode stays 0 — confirm.
    for player in game.level.player_sprites:
        pid = player.player_id
        score_history[pid, i] = player.stats.exp
        # Average over the episodes played so far only. Averaging the whole
        # row would also count the not-yet-played (all-zero) episodes.
        avg_score[pid] = score_history[pid, :i + 1].mean()

    mean_avg = np.mean(avg_score)
    mean_best = np.mean(best_score)
    if mean_avg > mean_best:
        print(f"\nNew Best score: {mean_avg}\nOld Best score: {mean_best}")
        # Snapshot, not alias: avg_score is updated in place every episode,
        # so plain assignment would make best_score track it and no later
        # episode could ever register as a new best.
        best_score = avg_score.copy()
        print("Saving models for agent...")
        agent.save_models(actr_chkpt="player_actor", crtc_chkpt="player_critic")
        print("Models saved ...\n")
    else:
        print(f"Average score of round: {mean_avg}\nBest score: {mean_best}")
|
2023-11-23 11:44:23 +00:00
|
|
|
|
|
|
|
|
2023-11-25 11:47:07 +00:00
|
|
|
from pathlib import Path

# Training finished: persist the final agent weights regardless of score.
print("\nEpisodes done, saving models...")
agent.save_models(actr_chkpt="player_actor", crtc_chkpt="player_critic")
print("Models saved ...\n")

# Plot one score curve per player across episodes. score_history is laid out
# (player, episode), and plt.plot draws one line per *column* of a 2-D array,
# so transpose it; untransposed it would draw one line per episode instead.
plt.plot(score_history.T)
plt.xlabel("Episode")
plt.ylabel("Score (player exp)")
# savefig does not create missing directories, so make sure the target exists.
Path(figure_file).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_file)

# NOTE(review): if Game.quit() exits the process (e.g. via sys.exit()),
# plt.show() below is never reached — confirm.
game.quit()
plt.show()
|