pneuma-pygame/main.py

60 lines
1.8 KiB
Python
Raw Normal View History

2023-11-19 03:27:47 +00:00
"""Training entry point for pneuma-pygame.

Runs a number of game episodes back to back. Each player's agent is carried
over from one episode to the next, and every agent's models are saved at the
end of each episode.
"""
from os import environ

# This has to be set before pygame is first imported, or pygame still prints
# its support banner. ``game`` presumably imports pygame (TODO confirm), so
# the assignment must come before ``from game import Game``.
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'

from tqdm import tqdm  # noqa: E402

from game import Game  # noqa: E402

# Experience removed from every player, per enemy still present, when the
# final episode ends.
ENEMY_SURVIVAL_PENALTY = 5


def _restore_agents(game, saved_agents):
    """Re-attach saved agents (``player_id -> agent``) to the level's players.

    Players whose id has no saved agent keep whatever agent they already have.
    """
    for player in game.level.player_sprites:
        if player.player_id in saved_agents:
            player.agent = saved_agents[player.player_id]


def _run_episode(game, game_len):
    """Step the game at most ``game_len`` times, stopping once the level is done."""
    # leave=False: otherwise every episode leaves a finished bar on screen.
    for _ in tqdm(range(game_len), leave=False):
        if game.level.done:
            break
        game.run()


def _apply_end_of_training_penalty(game):
    """Remove experience from every player for each enemy still present."""
    enemies = game.level.enemy_sprites
    # NOTE(review): if enemy_sprites is a pygame Group rather than a list, this
    # comparison is always True; a plain truthiness test may be what was meant.
    if enemies != []:
        for player in game.level.player_sprites:
            for _enemy in enemies:
                player.stats.exp -= ENEMY_SURVIVAL_PENALTY
            player.update()


def main(n_episodes=1000, game_len=5000):
    """Train the agents for ``n_episodes`` episodes of at most ``game_len`` steps.

    Args:
        n_episodes: number of episodes to run.
        game_len: maximum number of ``game.run()`` calls per episode.
    """
    game = Game()
    saved_agents = {}

    for episode in tqdm(range(n_episodes)):
        if episode != 0:
            # TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
            game.level.__init__(reset=True)
            # TODO: Make game.level.reset_map() so we don't have to pull out and re-load the agents every time
            _restore_agents(game, saved_agents)

        _run_episode(game, game_len)

        # Remember each player's agent so it can be re-attached after the
        # level is re-initialised at the start of the next episode.
        saved_agents = {
            player.player_id: player.agent
            for player in game.level.player_sprites
        }

        if episode == n_episodes - 1:
            _apply_end_of_training_penalty(game)

        for player in game.level.player_sprites:
            player.agent.save_models()

        # TODO: Record a score for each player every episode. Keep a running
        # average over the last 100 episodes, save the models whenever it beats
        # the best average so far, and plot the history to 'plots/score.png'.


if __name__ == '__main__':
    main()