Implemented Maxwell distribution for rewards

This commit is contained in:
Vasilis Valatsos 2023-12-04 05:08:41 +01:00
parent 7ca5063592
commit ca29a0e6dc
4 changed files with 16 additions and 11 deletions

Binary file not shown.

Binary file not shown.

View file

@@ -62,6 +62,7 @@ class Player(pygame.sprite.Sprite):
chkpt_dir, chkpt_dir,
no_load=False): no_load=False):
self.max_num_enemies = len(self.distance_direction_from_enemy)
self.get_current_state() self.get_current_state()
self.num_features = len(self.state_features) self.num_features = len(self.state_features)
@@ -158,9 +159,17 @@ class Player(pygame.sprite.Sprite):
self.reward_features = [ self.reward_features = [
self.stats.exp, self.stats.exp,
2*np.exp(-nearest_dist**2), 2*np.exp(-nearest_dist**2),
np.exp(-nearest_enemy.stats.health),
-np.exp(-self.stats.health**2) 1/(np.exp((nearest_enemy.stats.health -
nearest_enemy.stats.monster_info['health'])/nearest_enemy.stats.monster_info['health'])) - 1,
1/(np.exp((len(self.distance_direction_from_enemy) -
self.max_num_enemies)/self.max_num_enemies)) - 1,
1 - 1/(np.exp((self.stats.health -
self.stats.stats['health'])/self.stats.stats['health']))
if not self.is_dead() > 0 else -1 if not self.is_dead() > 0 else -1
] ]
@@ -233,7 +242,7 @@ class Player(pygame.sprite.Sprite):
self.stats.energy_recovery() self.stats.energy_recovery()
else: else:
self.stats.exp = max(0, self.stats.exp - .01) self.stats.exp = max(-1, self.stats.exp - .1)
# Refresh player based on input and animate # Refresh player based on input and animate
self.get_status() self.get_status()

View file

@@ -1,4 +1,4 @@
import os import random
import argparse import argparse
import torch as T import torch as T
import numpy as np import numpy as np
@@ -55,7 +55,7 @@ if __name__ == "__main__":
parser.add_argument('--horizon', parser.add_argument('--horizon',
type=int, type=int,
default=200, default=2048,
help="The number of steps per update") help="The number of steps per update")
parser.add_argument('--show_pg', parser.add_argument('--show_pg',
@@ -100,6 +100,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
T.manual_seed(args.seed) T.manual_seed(args.seed)
@@ -219,12 +220,6 @@ if __name__ == "__main__":
print(f"Models saved to {chkpt_path}") print(f"Models saved to {chkpt_path}")
else:
print(f"\nScore this round for player {player.player_id}:\
{player.stats.exp}")
# End of training session # End of training session
print("End of episodes.\ print("End of episodes.\
\nExiting game...") \nExiting game...")
@@ -264,4 +259,5 @@ if __name__ == "__main__":
for total in total_loss: for total in total_loss:
plt.plot(total) plt.plot(total)
plt.savefig(f"{figure_folder}/total_loss.png") plt.savefig(f"{figure_folder}/total_loss.png")
game.quit() game.quit()