diff --git a/agents/ppo/agent.py b/agents/ppo/agent.py
index 69fb1cc..06899ea 100644
--- a/agents/ppo/agent.py
+++ b/agents/ppo/agent.py
@@ -15,28 +15,20 @@ class Agent:
         self.n_epochs = n_epochs
         self.gae_lambda = gae_lambda
 
-        print("Preparing Actor model...")
         self.actor = ActorNetwork(input_dims, n_actions, alpha)
-        print(f"Actor network activated using {self.actor.device}")
-        print("\nPreparing Critic model...")
         self.critic = CriticNetwork(input_dims, alpha)
-        print(f"Critic network activated using {self.critic.device}")
         self.memory = PPOMemory(batch_size)
 
     def remember(self, state, action, probs, vals, reward, done):
         self.memory.store_memory(state, action, probs, vals, reward, done)
 
-    def save_models(self):
-        print('... saving models ...')
-        self.actor.save_checkpoint()
-        self.critic.save_checkpoint()
-        print('... done ...')
+    def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
+        self.actor.save_checkpoint(actr_chkpt)
+        self.critic.save_checkpoint(crtc_chkpt)
 
-    def load_models(self):
-        print('... loading models ...')
-        self.actor.load_checkpoint()
-        self.critic.load_checkpoint()
-        print('.. done ...')
+    def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
+        self.actor.load_checkpoint(actr_chkpt)
+        self.critic.load_checkpoint(crtc_chkpt)
 
     def choose_action(self, observation):
         state = T.tensor(observation, dtype=T.float).to(self.actor.device)
@@ -56,7 +48,7 @@ class Agent:
         state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
 
         values = vals_arr
-        advantage = np.zeros(len(reward_arr), dtype=np.float32)
+        advantage = np.zeros(len(reward_arr), dtype=np.float64)
 
         for t in range(len(reward_arr)-1):
             discount = 1
diff --git a/agents/ppo/brain.py b/agents/ppo/brain.py
index 770ff1a..3ab73b3 100644
--- a/agents/ppo/brain.py
+++ b/agents/ppo/brain.py
@@ -55,7 +55,7 @@ class ActorNetwork(nn.Module):
     def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
         super(ActorNetwork, self).__init__()
 
-        self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
+        self.chkpt_dir = chkpt_dir
 
         self.actor = nn.Sequential(
             nn.Linear(input_dim, fc1_dims),
@@ -68,8 +68,7 @@ class ActorNetwork(nn.Module):
 
         self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
 
-        self.device = T.device('cuda:0' if T.cuda.is_available() else (
-            'mps' if T.backends.mps.is_available() else 'cpu'))
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
 
         self.to(self.device)
 
@@ -79,11 +78,11 @@ class ActorNetwork(nn.Module):
 
         return dist
 
-    def save_checkpoint(self):
-        T.save(self.state_dict(), self.checkpoint_file)
+    def save_checkpoint(self, filename = 'actor_torch_ppo'):
+        T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
 
-    def load_checkpoint(self):
-        self.load_state_dict(T.load(self.checkpoint_file))
+    def load_checkpoint(self, filename = 'actor_torch_ppo'):
+        self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
 
 
 class CriticNetwork(nn.Module):
@@ -91,7 +90,7 @@ class CriticNetwork(nn.Module):
     def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
         super(CriticNetwork, self).__init__()
 
-        self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
+        self.chkpt_dir = chkpt_dir
 
         self.critic = nn.Sequential(
             nn.Linear(input_dims, fc1_dims),
@@ -103,8 +102,7 @@ class CriticNetwork(nn.Module):
 
         self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
 
-        self.device = T.device('cuda:0' if T.cuda.is_available() else (
-            'mps' if T.backends.mps.is_available() else 'cpu'))
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
 
         self.to(self.device)
 
@@ -112,8 +110,8 @@ class CriticNetwork(nn.Module):
         value = self.critic(state)
         return value
 
-    def save_checkpoint(self):
-        T.save(self.state_dict(), self.checkpoint_file)
+    def save_checkpoint(self, filename = 'critic_torch_ppo'):
+        T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
 
-    def load_checkpoint(self):
-        self.load_state_dict(T.load(self.checkpoint_file))
+    def load_checkpoint(self, filename = 'critic_torch_ppo'):
+        self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
diff --git a/configs/game/monster_config.py b/configs/game/monster_config.py
index 58b30c2..0b77079 100644
--- a/configs/game/monster_config.py
+++ b/configs/game/monster_config.py
@@ -9,4 +9,4 @@ monster_data = {
     'squid': {'id': 1, 'health': .1, 'exp': 1, 'attack': .5, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
     'raccoon': {'id': 2, 'health': .3, 'exp': 2.5, 'attack': .8, 'attack_type': 'claw', 'speed': 2, 'knockback': 20, 'attack_radius': 120, 'notice_radius': 400},
     'spirit': {'id': 3, 'health': .1, 'exp': 1.1, 'attack': .6, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, 'attack_radius': 60, 'notice_radius': 350},
-    'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .4, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}}
+    'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .2, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}}
diff --git a/entities/player.py b/entities/player.py
index ee232b4..59c50ee 100644
--- a/entities/player.py
+++ b/entities/player.py
@@ -120,9 +120,9 @@ class Player(pygame.sprite.Sprite):
 
         self.reward_features = [
             self.stats.exp,
-            np.exp(-nearest_dist**2),
-            np.exp(-nearest_enemy.stats.health**2),
-            -np.exp(-self.stats.health)
+            2*np.exp(-nearest_dist**2),
+            np.exp(-nearest_enemy.stats.health),
+            -np.exp(-self.stats.health**2)
         ]
 
         self.state_features = [
@@ -158,20 +158,21 @@ class Player(pygame.sprite.Sprite):
         self.num_features = len(self.state_features)
 
     def setup_agent(self):
+        print(f"Initializing Agent {self.player_id} ...")
         self.agent = Agent(
             input_dims=len(self.state_features),
             n_actions=len(self._input.possible_actions),
             batch_size=5,
             n_epochs=4)
+        print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
         try:
-            self.agent.load_models()
+            self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
+            print("Models loaded ...\n")
         except FileNotFoundError:
-            print("FileNotFoundError for agent.load_model().\
-                  Skipping loading...")
+            print("FileNotFound for agent. Skipping loading...\n")
 
     def is_dead(self):
         if self.stats.health <= 0:
-            self.stats.exp = max(0, self.stats.exp - .5)
             return True
         else:
             return False
@@ -202,9 +203,6 @@ class Player(pygame.sprite.Sprite):
 
         self.get_current_state()
 
-        if self.is_dead():
-            self.agent.learn()
-
         # Refresh objects based on input
         self.status = self._input.status
 
@@ -218,3 +216,6 @@ class Player(pygame.sprite.Sprite):
         self.stats.health_recovery()
         self.stats.energy_recovery()
         self._input.cooldowns(self._input.combat.vulnerable)
+
+        if self.is_dead():
+            self.stats.exp = max(-1, self.stats.exp - .5)
diff --git a/game.py b/game.py
index c178d4a..174d5af 100644
--- a/game.py
+++ b/game.py
@@ -12,7 +12,7 @@ class Game:
         pygame.init()
 
         self.screen = pygame.display.set_mode(
-            (WIDTH, HEIGHT))
+            (WIDTH, HEIGHT), pygame.HIDDEN)
 
         pygame.display.set_caption('Pneuma')
 
@@ -22,8 +22,11 @@ class Game:
 
         self.level = Level()
 
+        self.max_num_players = len(self.level.player_sprites)
+
     def calc_score(self):
-        self.scores = [0 for _ in range(len(self.level.player_sprites))]
+
+        self.scores = [0 for _ in range(self.max_num_players)]
 
         for player in self.level.player_sprites:
             self.scores[player.player_id] = player.stats.exp
@@ -39,7 +42,7 @@ class Game:
 
         self.screen.fill(WATER_COLOR)
 
-        self.level.run(who='observer')
+        self.level.run()
 
         pygame.display.update()
         self.clock.tick(FPS)
diff --git a/level/level.py b/level/level.py
index be088a0..0a95ca1 100644
--- a/level/level.py
+++ b/level/level.py
@@ -172,17 +172,9 @@ class Level:
         if who == 'observer':
             self.visible_sprites.custom_draw(self.observer)
             self.ui.display(self.observer)
-        elif who == 'player':
-            self.visible_sprites.custom_draw(self.player)
-            self.ui.display(self.player)
 
         debug('v0.8')
 
-        for player in self.player_sprites:
-            if player.is_dead():
-                self.dead_players[player.player_id] = True
-
-        self.done = True if self.dead_players.all() == 1 else False
 
         if not self.game_paused:
             # Update the game
@@ -191,8 +183,16 @@ class Level:
 
             self.get_players_enemies()
             self.get_distance_direction()
-            self.apply_damage_to_player()
             self.visible_sprites.update()
+            self.apply_damage_to_player()
 
         else:
             debug('PAUSED')
+
+        for player in self.player_sprites:
+            if player.is_dead():
+                self.dead_players[player.player_id] = True
+
+        self.done = True if self.dead_players.all() == 1 else False
+
+
diff --git a/main.py b/main.py
index a506d59..94312f2 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,5 @@
+import random
+import torch as T
 import numpy as np
 import matplotlib.pyplot as plt
 
@@ -8,68 +10,74 @@ from os import environ
 environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
 
-if __name__ == '__main__':
+random.seed(1)
+np.random.seed(1)
+T.manual_seed(1)
 
-    n_episodes = 1000
-    game_len = 10000
+n_episodes = 1000
+game_len = 10000
 
-    figure_file = 'plots/score.png'
-    best_score = 0
-    avg_score = 0
+figure_file = 'plots/score.png'
 
-    game = Game()
+game = Game()
 
-    agent_list = []
-    exp_points_list = []
-    score_history = np.zeros(
-        shape=(len(game.level.player_sprites), n_episodes, ))
-    best_score = np.zeros(len(game.level.player_sprites))
-    avg_score = np.zeros(len(game.level.player_sprites))
-    for i in tqdm(range(n_episodes)):
-        # TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
-        if i != 0:
-            game.level.__init__(reset=True)
-            # TODO: Make game.level.reset_map() so we don't pull out and load the agent every time (There is -definitevly- a better way)
-            for player in game.level.player_sprites:
-                for agent in agent_list:
-                    player.agent = agent_list[player.player_id]
-                    player.stats.exp = score_history[player.player_id][i-1]
+agent_list = [0 for _ in range(game.max_num_players)]
 
-            agent_list = []
-
-        for j in range(game_len):
-            if not game.level.done:
-
-                game.run()
-                game.calc_score()
-
-                if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
-                    for player in game.level.player_sprites:
-                        for enemy in game.level.enemy_sprites:
-                            player.stats.exp *= .95
-            else:
-                break
+score_history = np.zeros(
+    shape=(game.max_num_players, n_episodes))
+best_score = np.zeros(game.max_num_players)
+avg_score = np.zeros(game.max_num_players)
+for i in tqdm(range(n_episodes)):
+    # TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
+    if i != 0:
+        game.level.__init__(reset=True)
+        # TODO: Make game.level.reset_map() so we don't pull out and load the agent every time (There is -definitevly- a better way)
         for player in game.level.player_sprites:
-            agent_list.append(player.agent)
-            exp_points = player.stats.exp
-            score_history[player.player_id][i] = exp_points
-            avg_score[player.player_id] = np.mean(
-                score_history[player.player_id])
-            if avg_score[player.player_id] >= best_score[player.player_id]:
-                player.agent.save_models()
-                best_score[player.player_id] = avg_score[player.player_id]
+            for agent in agent_list:
+                player.agent = agent_list[player.player_id]
+                player.stats.exp = score_history[player.player_id][i-1]
 
-            print(
-                f"\nCumulative score for player {player.player_id}:\
-                    {score_history[0][i]}\
-                    \nAverage score for player {player.player_id}:\
-                    {avg_score[player.player_id]}\
-                    \nBest score for player {player.player_id}:\
-                    {best_score[player.player_id]}")
+    agent_list = [0 for _ in range(game.max_num_players)]
 
-    plt.plot(score_history[0])
+    for j in range(game_len):
+        if not game.level.done:
 
-    game.quit()
+            game.run()
+            game.calc_score()
 
-    plt.show()
+            for player in game.level.player_sprites:
+                if player.is_dead():
+                    agent_list[player.player_id] = player.agent
+                    player.kill()
+
+            # if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
+            #     for player in game.level.player_sprites:
+            #         for enemy in game.level.enemy_sprites:
+            #             player.stats.exp *= .95
+        else:
+            break
+
+    for player in game.level.player_sprites:
+        if not player.is_dead():
+            agent_list[player.player_id] = player.agent
+        exp_points = player.stats.exp
+        score_history[player.player_id][i] = exp_points
+        avg_score[player.player_id] = np.mean(
+            score_history[player.player_id])
+        if avg_score[player.player_id] > best_score[player.player_id]:
+            best_score[player.player_id] = avg_score[player.player_id]
+            print(f"Saving models for agent {player.player_id}...")
+            player.agent.save_models(actr_chkpt = f"player_{player.player_id}_actor", crtc_chkpt = f"player_{player.player_id}_critic")
+            print("Models saved ...\n")
+
+        print(
+            f"\nCumulative score for player {player.player_id}: {score_history[0][i]}\nAverage score for player {player.player_id}: {avg_score[player.player_id]}\nBest score for player {player.player_id}: {best_score[player.player_id]}")
+
+
+plt.plot(score_history[0])
+
+game.quit()
+
+plt.show()
diff --git a/tmp/ppo/actor_torch_ppo b/tmp/ppo/actor_torch_ppo
deleted file mode 100644
index 9876a62..0000000
Binary files a/tmp/ppo/actor_torch_ppo and /dev/null differ
diff --git a/tmp/ppo/critic_torch_ppo b/tmp/ppo/critic_torch_ppo
deleted file mode 100644
index d5b6a6f..0000000
Binary files a/tmp/ppo/critic_torch_ppo and /dev/null differ
diff --git a/tmp/ppo/player_2_actor b/tmp/ppo/player_2_actor
new file mode 100644
index 0000000..aa34020
Binary files /dev/null and b/tmp/ppo/player_2_actor differ
diff --git a/tmp/ppo/player_2_critic b/tmp/ppo/player_2_critic
new file mode 100644
index 0000000..23e2717
Binary files /dev/null and b/tmp/ppo/player_2_critic differ
diff --git a/tmp/ppo/player_3_actor b/tmp/ppo/player_3_actor
new file mode 100644
index 0000000..fc3b100
Binary files /dev/null and b/tmp/ppo/player_3_actor differ
diff --git a/tmp/ppo/player_3_critic b/tmp/ppo/player_3_critic
new file mode 100644
index 0000000..017f4eb
Binary files /dev/null and b/tmp/ppo/player_3_critic differ
diff --git a/tmp/ppo/player_4_actor b/tmp/ppo/player_4_actor
new file mode 100644
index 0000000..f448d6c
Binary files /dev/null and b/tmp/ppo/player_4_actor differ
diff --git a/tmp/ppo/player_4_critic b/tmp/ppo/player_4_critic
new file mode 100644
index 0000000..752dcb3
Binary files /dev/null and b/tmp/ppo/player_4_critic differ