Fixed errors for MARL
This commit is contained in:
parent
1a6ed25673
commit
8809c1b06c
15 changed files with 108 additions and 106 deletions
|
@ -15,28 +15,20 @@ class Agent:
|
||||||
self.n_epochs = n_epochs
|
self.n_epochs = n_epochs
|
||||||
self.gae_lambda = gae_lambda
|
self.gae_lambda = gae_lambda
|
||||||
|
|
||||||
print("Preparing Actor model...")
|
|
||||||
self.actor = ActorNetwork(input_dims, n_actions, alpha)
|
self.actor = ActorNetwork(input_dims, n_actions, alpha)
|
||||||
print(f"Actor network activated using {self.actor.device}")
|
|
||||||
print("\nPreparing Critic model...")
|
|
||||||
self.critic = CriticNetwork(input_dims, alpha)
|
self.critic = CriticNetwork(input_dims, alpha)
|
||||||
print(f"Critic network activated using {self.critic.device}")
|
|
||||||
self.memory = PPOMemory(batch_size)
|
self.memory = PPOMemory(batch_size)
|
||||||
|
|
||||||
def remember(self, state, action, probs, vals, reward, done):
|
def remember(self, state, action, probs, vals, reward, done):
|
||||||
self.memory.store_memory(state, action, probs, vals, reward, done)
|
self.memory.store_memory(state, action, probs, vals, reward, done)
|
||||||
|
|
||||||
def save_models(self):
|
def save_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
||||||
print('... saving models ...')
|
self.actor.save_checkpoint(actr_chkpt)
|
||||||
self.actor.save_checkpoint()
|
self.critic.save_checkpoint(crtc_chkpt)
|
||||||
self.critic.save_checkpoint()
|
|
||||||
print('... done ...')
|
|
||||||
|
|
||||||
def load_models(self):
|
def load_models(self, actr_chkpt = 'actor_ppo', crtc_chkpt = 'critic_ppo'):
|
||||||
print('... loading models ...')
|
self.actor.load_checkpoint(actr_chkpt)
|
||||||
self.actor.load_checkpoint()
|
self.critic.load_checkpoint(crtc_chkpt)
|
||||||
self.critic.load_checkpoint()
|
|
||||||
print('.. done ...')
|
|
||||||
|
|
||||||
def choose_action(self, observation):
|
def choose_action(self, observation):
|
||||||
state = T.tensor(observation, dtype=T.float).to(self.actor.device)
|
state = T.tensor(observation, dtype=T.float).to(self.actor.device)
|
||||||
|
@ -56,7 +48,7 @@ class Agent:
|
||||||
state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
|
state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
|
||||||
|
|
||||||
values = vals_arr
|
values = vals_arr
|
||||||
advantage = np.zeros(len(reward_arr), dtype=np.float32)
|
advantage = np.zeros(len(reward_arr), dtype=np.float64)
|
||||||
|
|
||||||
for t in range(len(reward_arr)-1):
|
for t in range(len(reward_arr)-1):
|
||||||
discount = 1
|
discount = 1
|
||||||
|
|
|
@ -55,7 +55,7 @@ class ActorNetwork(nn.Module):
|
||||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||||
super(ActorNetwork, self).__init__()
|
super(ActorNetwork, self).__init__()
|
||||||
|
|
||||||
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
|
self.chkpt_dir = chkpt_dir
|
||||||
|
|
||||||
self.actor = nn.Sequential(
|
self.actor = nn.Sequential(
|
||||||
nn.Linear(input_dim, fc1_dims),
|
nn.Linear(input_dim, fc1_dims),
|
||||||
|
@ -68,8 +68,7 @@ class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||||
|
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
|
@ -79,11 +78,11 @@ class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
def save_checkpoint(self):
|
def save_checkpoint(self, filename = 'actor_torch_ppo'):
|
||||||
T.save(self.state_dict(), self.checkpoint_file)
|
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self, filename = 'actor_torch_ppo'):
|
||||||
self.load_state_dict(T.load(self.checkpoint_file))
|
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
|
||||||
|
|
||||||
|
|
||||||
class CriticNetwork(nn.Module):
|
class CriticNetwork(nn.Module):
|
||||||
|
@ -91,7 +90,7 @@ class CriticNetwork(nn.Module):
|
||||||
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||||
super(CriticNetwork, self).__init__()
|
super(CriticNetwork, self).__init__()
|
||||||
|
|
||||||
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
|
self.chkpt_dir = chkpt_dir
|
||||||
|
|
||||||
self.critic = nn.Sequential(
|
self.critic = nn.Sequential(
|
||||||
nn.Linear(input_dims, fc1_dims),
|
nn.Linear(input_dims, fc1_dims),
|
||||||
|
@ -103,8 +102,7 @@ class CriticNetwork(nn.Module):
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||||
|
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
|
@ -112,8 +110,8 @@ class CriticNetwork(nn.Module):
|
||||||
value = self.critic(state)
|
value = self.critic(state)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def save_checkpoint(self):
|
def save_checkpoint(self, filename = 'critic_torch_ppo'):
|
||||||
T.save(self.state_dict(), self.checkpoint_file)
|
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self, filename = 'critic_torch_ppo'):
|
||||||
self.load_state_dict(T.load(self.checkpoint_file))
|
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
|
||||||
|
|
|
@ -9,4 +9,4 @@ monster_data = {
|
||||||
'squid': {'id': 1, 'health': .1, 'exp': 1, 'attack': .5, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
|
'squid': {'id': 1, 'health': .1, 'exp': 1, 'attack': .5, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
|
||||||
'raccoon': {'id': 2, 'health': .3, 'exp': 2.5, 'attack': .8, 'attack_type': 'claw', 'speed': 2, 'knockback': 20, 'attack_radius': 120, 'notice_radius': 400},
|
'raccoon': {'id': 2, 'health': .3, 'exp': 2.5, 'attack': .8, 'attack_type': 'claw', 'speed': 2, 'knockback': 20, 'attack_radius': 120, 'notice_radius': 400},
|
||||||
'spirit': {'id': 3, 'health': .1, 'exp': 1.1, 'attack': .6, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, 'attack_radius': 60, 'notice_radius': 350},
|
'spirit': {'id': 3, 'health': .1, 'exp': 1.1, 'attack': .6, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, 'attack_radius': 60, 'notice_radius': 350},
|
||||||
'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .4, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}}
|
'bamboo': {'id': 4, 'health': .07, 'exp': 1.2, 'attack': .2, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, 'attack_radius': 50, 'notice_radius': 300}}
|
||||||
|
|
|
@ -120,9 +120,9 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
self.reward_features = [
|
self.reward_features = [
|
||||||
self.stats.exp,
|
self.stats.exp,
|
||||||
np.exp(-nearest_dist**2),
|
2*np.exp(-nearest_dist**2),
|
||||||
np.exp(-nearest_enemy.stats.health**2),
|
np.exp(-nearest_enemy.stats.health),
|
||||||
-np.exp(-self.stats.health)
|
-np.exp(-self.stats.health**2)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.state_features = [
|
self.state_features = [
|
||||||
|
@ -158,20 +158,21 @@ class Player(pygame.sprite.Sprite):
|
||||||
self.num_features = len(self.state_features)
|
self.num_features = len(self.state_features)
|
||||||
|
|
||||||
def setup_agent(self):
|
def setup_agent(self):
|
||||||
|
print(f"Initializing Agent {self.player_id} ...")
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
input_dims=len(self.state_features),
|
input_dims=len(self.state_features),
|
||||||
n_actions=len(self._input.possible_actions),
|
n_actions=len(self._input.possible_actions),
|
||||||
batch_size=5,
|
batch_size=5,
|
||||||
n_epochs=4)
|
n_epochs=4)
|
||||||
|
print(f" Agent initialized using {self.agent.actor.device}. Attempting to load models ...")
|
||||||
try:
|
try:
|
||||||
self.agent.load_models()
|
self.agent.load_models(actr_chkpt = f"player_{self.player_id}_actor", crtc_chkpt = f"player_{self.player_id}_critic")
|
||||||
|
print("Models loaded ...\n")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("FileNotFoundError for agent.load_model().\
|
print("FileNotFound for agent. Skipping loading...\n")
|
||||||
Skipping loading...")
|
|
||||||
|
|
||||||
def is_dead(self):
|
def is_dead(self):
|
||||||
if self.stats.health <= 0:
|
if self.stats.health <= 0:
|
||||||
self.stats.exp = max(0, self.stats.exp - .5)
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
@ -202,9 +203,6 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
self.get_current_state()
|
self.get_current_state()
|
||||||
|
|
||||||
if self.is_dead():
|
|
||||||
self.agent.learn()
|
|
||||||
|
|
||||||
# Refresh objects based on input
|
# Refresh objects based on input
|
||||||
self.status = self._input.status
|
self.status = self._input.status
|
||||||
|
|
||||||
|
@ -218,3 +216,6 @@ class Player(pygame.sprite.Sprite):
|
||||||
self.stats.health_recovery()
|
self.stats.health_recovery()
|
||||||
self.stats.energy_recovery()
|
self.stats.energy_recovery()
|
||||||
self._input.cooldowns(self._input.combat.vulnerable)
|
self._input.cooldowns(self._input.combat.vulnerable)
|
||||||
|
|
||||||
|
if self.is_dead():
|
||||||
|
self.stats.exp = max(-1, self.stats.exp - .5)
|
||||||
|
|
9
game.py
9
game.py
|
@ -12,7 +12,7 @@ class Game:
|
||||||
pygame.init()
|
pygame.init()
|
||||||
|
|
||||||
self.screen = pygame.display.set_mode(
|
self.screen = pygame.display.set_mode(
|
||||||
(WIDTH, HEIGHT))
|
(WIDTH, HEIGHT), pygame.HIDDEN)
|
||||||
|
|
||||||
pygame.display.set_caption('Pneuma')
|
pygame.display.set_caption('Pneuma')
|
||||||
|
|
||||||
|
@ -22,8 +22,11 @@ class Game:
|
||||||
|
|
||||||
self.level = Level()
|
self.level = Level()
|
||||||
|
|
||||||
|
self.max_num_players = len(self.level.player_sprites)
|
||||||
|
|
||||||
def calc_score(self):
|
def calc_score(self):
|
||||||
self.scores = [0 for _ in range(len(self.level.player_sprites))]
|
|
||||||
|
self.scores = [0 for _ in range(self.max_num_players)]
|
||||||
|
|
||||||
for player in self.level.player_sprites:
|
for player in self.level.player_sprites:
|
||||||
self.scores[player.player_id] = player.stats.exp
|
self.scores[player.player_id] = player.stats.exp
|
||||||
|
@ -39,7 +42,7 @@ class Game:
|
||||||
|
|
||||||
self.screen.fill(WATER_COLOR)
|
self.screen.fill(WATER_COLOR)
|
||||||
|
|
||||||
self.level.run(who='observer')
|
self.level.run()
|
||||||
|
|
||||||
pygame.display.update()
|
pygame.display.update()
|
||||||
self.clock.tick(FPS)
|
self.clock.tick(FPS)
|
||||||
|
|
|
@ -172,17 +172,9 @@ class Level:
|
||||||
if who == 'observer':
|
if who == 'observer':
|
||||||
self.visible_sprites.custom_draw(self.observer)
|
self.visible_sprites.custom_draw(self.observer)
|
||||||
self.ui.display(self.observer)
|
self.ui.display(self.observer)
|
||||||
elif who == 'player':
|
|
||||||
self.visible_sprites.custom_draw(self.player)
|
|
||||||
self.ui.display(self.player)
|
|
||||||
|
|
||||||
debug('v0.8')
|
debug('v0.8')
|
||||||
|
|
||||||
for player in self.player_sprites:
|
|
||||||
if player.is_dead():
|
|
||||||
self.dead_players[player.player_id] = True
|
|
||||||
|
|
||||||
self.done = True if self.dead_players.all() == 1 else False
|
|
||||||
|
|
||||||
if not self.game_paused:
|
if not self.game_paused:
|
||||||
# Update the game
|
# Update the game
|
||||||
|
@ -191,8 +183,16 @@ class Level:
|
||||||
|
|
||||||
self.get_players_enemies()
|
self.get_players_enemies()
|
||||||
self.get_distance_direction()
|
self.get_distance_direction()
|
||||||
self.apply_damage_to_player()
|
|
||||||
self.visible_sprites.update()
|
self.visible_sprites.update()
|
||||||
|
self.apply_damage_to_player()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
debug('PAUSED')
|
debug('PAUSED')
|
||||||
|
|
||||||
|
for player in self.player_sprites:
|
||||||
|
if player.is_dead():
|
||||||
|
self.dead_players[player.player_id] = True
|
||||||
|
|
||||||
|
self.done = True if self.dead_players.all() == 1 else False
|
||||||
|
|
||||||
|
|
||||||
|
|
68
main.py
68
main.py
|
@ -1,3 +1,5 @@
|
||||||
|
import random
|
||||||
|
import torch as T
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@ -8,24 +10,25 @@ from os import environ
|
||||||
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
|
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
random.seed(1)
|
||||||
|
np.random.seed(1)
|
||||||
|
T.manual_seed(1)
|
||||||
|
|
||||||
n_episodes = 1000
|
n_episodes = 1000
|
||||||
game_len = 10000
|
game_len = 10000
|
||||||
|
|
||||||
figure_file = 'plots/score.png'
|
figure_file = 'plots/score.png'
|
||||||
best_score = 0
|
|
||||||
avg_score = 0
|
|
||||||
|
|
||||||
game = Game()
|
game = Game()
|
||||||
|
|
||||||
agent_list = []
|
agent_list = [0 for _ in range(game.max_num_players)]
|
||||||
exp_points_list = []
|
|
||||||
score_history = np.zeros(
|
score_history = np.zeros(
|
||||||
shape=(len(game.level.player_sprites), n_episodes, ))
|
shape=(game.max_num_players, n_episodes))
|
||||||
best_score = np.zeros(len(game.level.player_sprites))
|
best_score = np.zeros(game.max_num_players)
|
||||||
avg_score = np.zeros(len(game.level.player_sprites))
|
avg_score = np.zeros(game.max_num_players)
|
||||||
for i in tqdm(range(n_episodes)):
|
|
||||||
|
for i in tqdm(range(n_episodes)):
|
||||||
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
||||||
if i != 0:
|
if i != 0:
|
||||||
game.level.__init__(reset=True)
|
game.level.__init__(reset=True)
|
||||||
|
@ -35,7 +38,7 @@ if __name__ == '__main__':
|
||||||
player.agent = agent_list[player.player_id]
|
player.agent = agent_list[player.player_id]
|
||||||
player.stats.exp = score_history[player.player_id][i-1]
|
player.stats.exp = score_history[player.player_id][i-1]
|
||||||
|
|
||||||
agent_list = []
|
agent_list = [0 for _ in range(game.max_num_players)]
|
||||||
|
|
||||||
for j in range(game_len):
|
for j in range(game_len):
|
||||||
if not game.level.done:
|
if not game.level.done:
|
||||||
|
@ -43,33 +46,38 @@ if __name__ == '__main__':
|
||||||
game.run()
|
game.run()
|
||||||
game.calc_score()
|
game.calc_score()
|
||||||
|
|
||||||
if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
|
|
||||||
for player in game.level.player_sprites:
|
for player in game.level.player_sprites:
|
||||||
for enemy in game.level.enemy_sprites:
|
if player.is_dead():
|
||||||
player.stats.exp *= .95
|
agent_list[player.player_id] = player.agent
|
||||||
|
player.kill()
|
||||||
|
|
||||||
|
# if (j == game_len-1 or game.level.done) and game.level.enemy_sprites != []:
|
||||||
|
# for player in game.level.player_sprites:
|
||||||
|
# for enemy in game.level.enemy_sprites:
|
||||||
|
# player.stats.exp *= .95
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
for player in game.level.player_sprites:
|
for player in game.level.player_sprites:
|
||||||
agent_list.append(player.agent)
|
if not player.is_dead():
|
||||||
|
agent_list[player.player_id] = player.agent
|
||||||
exp_points = player.stats.exp
|
exp_points = player.stats.exp
|
||||||
score_history[player.player_id][i] = exp_points
|
score_history[player.player_id][i] = exp_points
|
||||||
avg_score[player.player_id] = np.mean(
|
avg_score[player.player_id] = np.mean(
|
||||||
score_history[player.player_id])
|
score_history[player.player_id])
|
||||||
if avg_score[player.player_id] >= best_score[player.player_id]:
|
if avg_score[player.player_id] > best_score[player.player_id]:
|
||||||
player.agent.save_models()
|
|
||||||
best_score[player.player_id] = avg_score[player.player_id]
|
best_score[player.player_id] = avg_score[player.player_id]
|
||||||
|
print(f"Saving models for agent {player.player_id}...")
|
||||||
|
player.agent.save_models(actr_chkpt = f"player_{player.player_id}_actor", crtc_chkpt = f"player_{player.player_id}_critic")
|
||||||
|
print("Models saved ...\n")
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"\nCumulative score for player {player.player_id}:\
|
f"\nCumulative score for player {player.player_id}: {score_history[0][i]}\nAverage score for player {player.player_id}: {avg_score[player.player_id]}\nBest score for player {player.player_id}: {best_score[player.player_id]}")
|
||||||
{score_history[0][i]}\
|
|
||||||
\nAverage score for player {player.player_id}:\
|
|
||||||
{avg_score[player.player_id]}\
|
|
||||||
\nBest score for player {player.player_id}:\
|
|
||||||
{best_score[player.player_id]}")
|
|
||||||
|
|
||||||
plt.plot(score_history[0])
|
|
||||||
|
|
||||||
game.quit()
|
|
||||||
|
|
||||||
plt.show()
|
plt.plot(score_history[0])
|
||||||
|
|
||||||
|
game.quit()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
Binary file not shown.
Binary file not shown.
BIN
tmp/ppo/player_2_actor
Normal file
BIN
tmp/ppo/player_2_actor
Normal file
Binary file not shown.
BIN
tmp/ppo/player_2_critic
Normal file
BIN
tmp/ppo/player_2_critic
Normal file
Binary file not shown.
BIN
tmp/ppo/player_3_actor
Normal file
BIN
tmp/ppo/player_3_actor
Normal file
Binary file not shown.
BIN
tmp/ppo/player_3_critic
Normal file
BIN
tmp/ppo/player_3_critic
Normal file
Binary file not shown.
BIN
tmp/ppo/player_4_actor
Normal file
BIN
tmp/ppo/player_4_actor
Normal file
Binary file not shown.
BIN
tmp/ppo/player_4_critic
Normal file
BIN
tmp/ppo/player_4_critic
Normal file
Binary file not shown.
Loading…
Reference in a new issue