Implemented Maxwell distribution for rewards
This commit is contained in:
parent
7ca5063592
commit
ca29a0e6dc
4 changed files with 16 additions and 11 deletions
Binary file not shown.
Binary file not shown.
|
@ -62,6 +62,7 @@ class Player(pygame.sprite.Sprite):
|
|||
chkpt_dir,
|
||||
no_load=False):
|
||||
|
||||
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
||||
self.get_current_state()
|
||||
self.num_features = len(self.state_features)
|
||||
|
||||
|
@ -158,9 +159,17 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.reward_features = [
|
||||
self.stats.exp,
|
||||
|
||||
2*np.exp(-nearest_dist**2),
|
||||
np.exp(-nearest_enemy.stats.health),
|
||||
-np.exp(-self.stats.health**2)
|
||||
|
||||
1/(np.exp((nearest_enemy.stats.health -
|
||||
nearest_enemy.stats.monster_info['health'])/nearest_enemy.stats.monster_info['health'])) - 1,
|
||||
|
||||
1/(np.exp((len(self.distance_direction_from_enemy) -
|
||||
self.max_num_enemies)/self.max_num_enemies)) - 1,
|
||||
|
||||
1 - 1/(np.exp((self.stats.health -
|
||||
self.stats.stats['health'])/self.stats.stats['health']))
|
||||
if not self.is_dead() > 0 else -1
|
||||
]
|
||||
|
||||
|
@ -233,7 +242,7 @@ class Player(pygame.sprite.Sprite):
|
|||
self.stats.energy_recovery()
|
||||
|
||||
else:
|
||||
self.stats.exp = max(0, self.stats.exp - .01)
|
||||
self.stats.exp = max(-1, self.stats.exp - .1)
|
||||
|
||||
# Refresh player based on input and animate
|
||||
self.get_status()
|
||||
|
|
12
pneuma.py
12
pneuma.py
|
@ -1,4 +1,4 @@
|
|||
import os
|
||||
import random
|
||||
import argparse
|
||||
import torch as T
|
||||
import numpy as np
|
||||
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
|||
|
||||
parser.add_argument('--horizon',
|
||||
type=int,
|
||||
default=200,
|
||||
default=2048,
|
||||
help="The number of steps per update")
|
||||
|
||||
parser.add_argument('--show_pg',
|
||||
|
@ -100,6 +100,7 @@ if __name__ == "__main__":
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
T.manual_seed(args.seed)
|
||||
|
||||
|
@ -219,12 +220,6 @@ if __name__ == "__main__":
|
|||
|
||||
print(f"Models saved to {chkpt_path}")
|
||||
|
||||
else:
|
||||
print(f"\nScore this round for player {player.player_id}:\
|
||||
{player.stats.exp}")
|
||||
|
||||
|
||||
|
||||
# End of training session
|
||||
print("End of episodes.\
|
||||
\nExiting game...")
|
||||
|
@ -264,4 +259,5 @@ if __name__ == "__main__":
|
|||
for total in total_loss:
|
||||
plt.plot(total)
|
||||
plt.savefig(f"{figure_folder}/total_loss.png")
|
||||
|
||||
game.quit()
|
||||
|
|
Loading…
Reference in a new issue