Implemented Maxwell distribution for rewards

This commit is contained in:
Vasilis Valatsos 2023-12-04 05:08:41 +01:00
parent 7ca5063592
commit ca29a0e6dc
4 changed files with 16 additions and 11 deletions

Binary file not shown.

Binary file not shown.

View file

@ -62,6 +62,7 @@ class Player(pygame.sprite.Sprite):
chkpt_dir,
no_load=False):
self.max_num_enemies = len(self.distance_direction_from_enemy)
self.get_current_state()
self.num_features = len(self.state_features)
@ -158,9 +159,17 @@ class Player(pygame.sprite.Sprite):
self.reward_features = [
self.stats.exp,
2*np.exp(-nearest_dist**2),
np.exp(-nearest_enemy.stats.health),
-np.exp(-self.stats.health**2)
1/(np.exp((nearest_enemy.stats.health -
nearest_enemy.stats.monster_info['health'])/nearest_enemy.stats.monster_info['health'])) - 1,
1/(np.exp((len(self.distance_direction_from_enemy) -
self.max_num_enemies)/self.max_num_enemies)) - 1,
1 - 1/(np.exp((self.stats.health -
self.stats.stats['health'])/self.stats.stats['health']))
if not self.is_dead() > 0 else -1
]
@ -233,7 +242,7 @@ class Player(pygame.sprite.Sprite):
self.stats.energy_recovery()
else:
self.stats.exp = max(0, self.stats.exp - .01)
self.stats.exp = max(-1, self.stats.exp - .1)
# Refresh player based on input and animate
self.get_status()

View file

@ -1,4 +1,4 @@
import os
import random
import argparse
import torch as T
import numpy as np
@ -55,7 +55,7 @@ if __name__ == "__main__":
parser.add_argument('--horizon',
type=int,
default=200,
default=2048,
help="The number of steps per update")
parser.add_argument('--show_pg',
@ -100,6 +100,7 @@ if __name__ == "__main__":
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
T.manual_seed(args.seed)
@ -219,12 +220,6 @@ if __name__ == "__main__":
print(f"Models saved to {chkpt_path}")
else:
print(f"\nScore this round for player {player.player_id}:\
{player.stats.exp}")
# End of training session
print("End of episodes.\
\nExiting game...")
@ -264,4 +259,5 @@ if __name__ == "__main__":
for total in total_loss:
plt.plot(total)
plt.savefig(f"{figure_folder}/total_loss.png")
game.quit()