Fixed errors for MARL

This commit is contained in:
Vasilis Valatsos 2023-11-23 16:37:02 +01:00
parent 1a6ed25673
commit 8809c1b06c
15 changed files with 108 additions and 106 deletions

View file

@ -15,28 +15,20 @@ class Agent:
self.n_epochs = n_epochs self.n_epochs = n_epochs
self.gae_lambda = gae_lambda self.gae_lambda = gae_lambda
print("Preparing Actor model...")
self.actor = ActorNetwork(input_dims, n_actions, alpha) self.actor = ActorNetwork(input_dims, n_actions, alpha)
print(f"Actor network activated using {self.actor.device}")
print("\nPreparing Critic model...")
self.critic = CriticNetwork(input_dims, alpha) self.critic = CriticNetwork(input_dims, alpha)
print(f"Critic network activated using {self.critic.device}")
self.memory = PPOMemory(batch_size) self.memory = PPOMemory(batch_size)
def remember(self, state, action, probs, vals, reward, done): def remember(self, state, action, probs, vals, reward, done):
self.memory.store_memory(state, action, probs, vals, reward, done) self.memory.store_memory(state, action, probs, vals, reward, done)
def save_models(self): def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
print('... saving models ...') self.actor.save_checkpoint(actr_chkpt)
self.actor.save_checkpoint() self.critic.save_checkpoint(crtc_chkpt)
self.critic.save_checkpoint()
print('... done ...')
def load_models(self): def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
print('... loading models ...') self.actor.load_checkpoint(actr_chkpt)
self.actor.load_checkpoint() self.critic.load_checkpoint(crtc_chkpt)
self.critic.load_checkpoint()
print('.. done ...')
def choose_action(self, observation): def choose_action(self, observation):
state = T.tensor(observation, dtype=T.float).to(self.actor.device) state = T.tensor(observation, dtype=T.float).to(self.actor.device)
@ -56,7 +48,7 @@ class Agent:
state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches() state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
values = vals_arr values = vals_arr
advantage = np.zeros(len(reward_arr), dtype=np.float32) advantage = np.zeros(len(reward_arr), dtype=np.float64)
for t in range(len(reward_arr)-1): for t in range(len(reward_arr)-1):
discount = 1 discount = 1

View file

@ -55,7 +55,7 @@ class ActorNetwork(nn.Module):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'): def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
super(ActorNetwork, self).__init__() super(ActorNetwork, self).__init__()
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo') self.chkpt_dir = chkpt_dir
self.actor = nn.Sequential( self.actor = nn.Sequential(
nn.Linear(input_dim, fc1_dims), nn.Linear(input_dim, fc1_dims),
@ -68,8 +68,7 @@ class ActorNetwork(nn.Module):
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else ( self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
'mps' if T.backends.mps.is_available() else 'cpu'))
self.to(self.device) self.to(self.device)
@ -79,11 +78,11 @@ class ActorNetwork(nn.Module):
return dist return dist
def save_checkpoint(self): def save_checkpoint(self, filename = 'actor_torch_ppo'):
T.save(self.state_dict(), self.checkpoint_file) T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def load_checkpoint(self): def load_checkpoint(self, filename = 'actor_torch_ppo'):
self.load_state_dict(T.load(self.checkpoint_file)) self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
class CriticNetwork(nn.Module): class CriticNetwork(nn.Module):
@ -91,7 +90,7 @@ class CriticNetwork(nn.Module):
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'): def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
super(CriticNetwork, self).__init__() super(CriticNetwork, self).__init__()
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo') self.chkpt_dir = chkpt_dir
self.critic = nn.Sequential( self.critic = nn.Sequential(
nn.Linear(input_dims, fc1_dims), nn.Linear(input_dims, fc1_dims),
@ -103,8 +102,7 @@ class CriticNetwork(nn.Module):
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5) self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else ( self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
'mps' if T.backends.mps.is_available() else 'cpu'))
self.to(self.device) self.to(self.device)
@ -112,8 +110,8 @@ class CriticNetwork(nn.Module):
value = self.critic(state) value = self.critic(state)
return value return value
def save_checkpoint(self): def save_checkpoint(self, filename = 'critic_torch_ppo'):
T.save(self.state_dict(), self.checkpoint_file) T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
def load_checkpoint(self): def load_checkpoint(self, filename = 'critic_torch_ppo'):
self.load_state_dict(T.load(self.checkpoint_file)) self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))

View file

@ -9,4 +9,4 @@ monster_data = {
'squid': {'id': 1, 'health': .1, 'exp': 1, 'attack': .5, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360}, 'squid': {'id': 1, 'health': .1, 'exp': 1, 'attack': .5, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
'raccoon': {'id': 2, 'health': .3, 'exp': 2.5, 'attack': .8, 'attack_type': 'claw', 'speed': 2, 'knockback': 20, 'attack_radius': 120, 'notice_radius': 400}, 'raccoon': {'id': 2, 'health': .3, 'exp': 2.5, 'attack': .8, 'attack_type': 'claw', 'speed': 2, 'knockback': 20, 'attack_radius': 120, 'notice_radius': 400},
'spirit': {'id': 3, 'health': .1, 'exp': 1.1, 'attack': .6, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, 'attack_radius': 60, 'notice_radius': 350}, 'spirit': {'id': 3, 'health': .1, 'exp': 1.1, 'attack': .6, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, 'attack_radius': 60, 'notice_radius': 350},
'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .4, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}} 'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .2, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}}

View file

@ -120,9 +120,9 @@ class Player(pygame.sprite.Sprite):
self.reward_features = [ self.reward_features = [
self.stats.exp, self.stats.exp,
np.exp(-nearest_dist**2), 2*np.exp(-nearest_dist**2),
np.exp(-nearest_enemy.stats.health**2), np.exp(-nearest_enemy.stats.health),
-np.exp(-self.stats.health) -np.exp(-self.stats.health**2)
] ]
self.state_features = [ self.state_features = [
@ -158,20 +158,21 @@ class Player(pygame.sprite.Sprite):
self.num_features = len(self.state_features) self.num_features = len(self.state_features)
def setup_agent(self): def setup_agent(self):
print(f"Initializing Agent {self.player_id} ...")
self.agent = Agent( self.agent = Agent(
input_dims=len(self.state_features), input_dims=len(self.state_features),
n_actions=len(self._input.possible_actions), n_actions=len(self._input.possible_actions),
batch_size=5, batch_size=5,
n_epochs=4) n_epochs=4)
print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
try: try:
self.agent.load_models() self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
print("Models loaded ...\n")
except FileNotFoundError: except FileNotFoundError:
print("FileNotFoundError for agent.load_model().\ print("FileNotFound for agent. Skipping loading...\n")
Skipping loading...")
def is_dead(self): def is_dead(self):
if self.stats.health <= 0: if self.stats.health <= 0:
self.stats.exp = max(0, self.stats.exp - .5)
return True return True
else: else:
return False return False
@ -202,9 +203,6 @@ class Player(pygame.sprite.Sprite):
self.get_current_state() self.get_current_state()
if self.is_dead():
self.agent.learn()
# Refresh objects based on input # Refresh objects based on input
self.status = self._input.status self.status = self._input.status
@ -218,3 +216,6 @@ class Player(pygame.sprite.Sprite):
self.stats.health_recovery() self.stats.health_recovery()
self.stats.energy_recovery() self.stats.energy_recovery()
self._input.cooldowns(self._input.combat.vulnerable) self._input.cooldowns(self._input.combat.vulnerable)
if self.is_dead():
self.stats.exp = max(-1, self.stats.exp - .5)

View file

@ -12,7 +12,7 @@ class Game:
pygame.init() pygame.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(WIDTH, HEIGHT)) (WIDTH, HEIGHT), pygame.HIDDEN)
pygame.display.set_caption('Pneuma') pygame.display.set_caption('Pneuma')
@ -22,8 +22,11 @@ class Game:
self.level = Level() self.level = Level()
self.max_num_players = len(self.level.player_sprites)
def calc_score(self): def calc_score(self):
self.scores = [0 for _ in range(len(self.level.player_sprites))]
self.scores = [0 for _ in range(self.max_num_players)]
for player in self.level.player_sprites: for player in self.level.player_sprites:
self.scores[player.player_id] = player.stats.exp self.scores[player.player_id] = player.stats.exp
@ -39,7 +42,7 @@ class Game:
self.screen.fill(WATER_COLOR) self.screen.fill(WATER_COLOR)
self.level.run(who='observer') self.level.run()
pygame.display.update() pygame.display.update()
self.clock.tick(FPS) self.clock.tick(FPS)

View file

@ -172,17 +172,9 @@ class Level:
if who == 'observer': if who == 'observer':
self.visible_sprites.custom_draw(self.observer) self.visible_sprites.custom_draw(self.observer)
self.ui.display(self.observer) self.ui.display(self.observer)
elif who == 'player':
self.visible_sprites.custom_draw(self.player)
self.ui.display(self.player)
debug('v0.8') debug('v0.8')
for player in self.player_sprites:
if player.is_dead():
self.dead_players[player.player_id] = True
self.done = True if self.dead_players.all() == 1 else False
if not self.game_paused: if not self.game_paused:
# Update the game # Update the game
@ -191,8 +183,16 @@ class Level:
self.get_players_enemies() self.get_players_enemies()
self.get_distance_direction() self.get_distance_direction()
self.apply_damage_to_player()
self.visible_sprites.update() self.visible_sprites.update()
self.apply_damage_to_player()
else: else:
debug('PAUSED') debug('PAUSED')
for player in self.player_sprites:
if player.is_dead():
self.dead_players[player.player_id] = True
self.done = True if self.dead_players.all() == 1 else False

50
main.py
View file

@ -1,3 +1,5 @@
import random
import torch as T
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -8,23 +10,24 @@ from os import environ
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1' environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
if __name__ == '__main__': random.seed(1)
np.random.seed(1)
T.manual_seed(1)
n_episodes = 1000 n_episodes = 1000
game_len = 10000 game_len = 10000
figure_file = 'plots/score.png' figure_file = 'plots/score.png'
best_score = 0
avg_score = 0
game = Game() game = Game()
agent_list = [] agent_list = [0 for _ in range(game.max_num_players)]
exp_points_list = []
score_history = np.zeros( score_history = np.zeros(
shape=(len(game.level.player_sprites), n_episodes, )) shape=(game.max_num_players, n_episodes))
best_score = np.zeros(len(game.level.player_sprites)) best_score = np.zeros(game.max_num_players)
avg_score = np.zeros(len(game.level.player_sprites)) avg_score = np.zeros(game.max_num_players)
for i in tqdm(range(n_episodes)): for i in tqdm(range(n_episodes)):
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste) # TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
if i != 0: if i != 0:
@ -35,7 +38,7 @@ if __name__ == '__main__':
player.agent = agent_list[player.player_id] player.agent = agent_list[player.player_id]
player.stats.exp = score_history[player.player_id][i-1] player.stats.exp = score_history[player.player_id][i-1]
agent_list = [] agent_list = [0 for _ in range(game.max_num_players)]
for j in range(game_len): for j in range(game_len):
if not game.level.done: if not game.level.done:
@ -43,30 +46,35 @@ if __name__ == '__main__':
game.run() game.run()
game.calc_score() game.calc_score()
if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
for player in game.level.player_sprites: for player in game.level.player_sprites:
for enemy in game.level.enemy_sprites: if player.is_dead():
player.stats.exp *= .95 agent_list[player.player_id] = player.agent
player.kill()
# if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
# for player in game.level.player_sprites:
# for enemy in game.level.enemy_sprites:
# player.stats.exp *= .95
else: else:
break break
for player in game.level.player_sprites: for player in game.level.player_sprites:
agent_list.append(player.agent) if not player.is_dead():
agent_list[player.player_id] = player.agent
exp_points = player.stats.exp exp_points = player.stats.exp
score_history[player.player_id][i] = exp_points score_history[player.player_id][i] = exp_points
avg_score[player.player_id] = np.mean( avg_score[player.player_id] = np.mean(
score_history[player.player_id]) score_history[player.player_id])
if avg_score[player.player_id] >= best_score[player.player_id]: if avg_score[player.player_id] > best_score[player.player_id]:
player.agent.save_models()
best_score[player.player_id] = avg_score[player.player_id] best_score[player.player_id] = avg_score[player.player_id]
print(f"Saving models for agent {player.player_id}...")
player.agent.save_models(actr_chkpt = f"player_{player.player_id}_actor", crtc_chkpt = f"player_{player.player_id}_critic")
print("Models saved ...\n")
print( print(
f"\nCumulative score for player {player.player_id}:\ f"\nCumulative score for player {player.player_id}: {score_history[0][i]}\nAverage score for player {player.player_id}: {avg_score[player.player_id]}\nBest score for player {player.player_id}: {best_score[player.player_id]}")
{score_history[0][i]}\
\nAverage score for player {player.player_id}:\
{avg_score[player.player_id]}\
\nBest score for player {player.player_id}:\
{best_score[player.player_id]}")
plt.plot(score_history[0]) plt.plot(score_history[0])

Binary file not shown.

Binary file not shown.

BIN
tmp/ppo/player_2_actor Normal file

Binary file not shown.

BIN
tmp/ppo/player_2_critic Normal file

Binary file not shown.

BIN
tmp/ppo/player_3_actor Normal file

Binary file not shown.

BIN
tmp/ppo/player_3_critic Normal file

Binary file not shown.

BIN
tmp/ppo/player_4_actor Normal file

Binary file not shown.

BIN
tmp/ppo/player_4_critic Normal file

Binary file not shown.