Implemented Maxwell distribution for rewards
This commit is contained in:
parent
7ca5063592
commit
ca29a0e6dc
4 changed files with 16 additions and 11 deletions
Binary file not shown.
Binary file not shown.
|
@ -62,6 +62,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
chkpt_dir,
|
chkpt_dir,
|
||||||
no_load=False):
|
no_load=False):
|
||||||
|
|
||||||
|
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
||||||
self.get_current_state()
|
self.get_current_state()
|
||||||
self.num_features = len(self.state_features)
|
self.num_features = len(self.state_features)
|
||||||
|
|
||||||
|
@ -158,9 +159,17 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
self.reward_features = [
|
self.reward_features = [
|
||||||
self.stats.exp,
|
self.stats.exp,
|
||||||
|
|
||||||
2*np.exp(-nearest_dist**2),
|
2*np.exp(-nearest_dist**2),
|
||||||
np.exp(-nearest_enemy.stats.health),
|
|
||||||
-np.exp(-self.stats.health**2)
|
1/(np.exp((nearest_enemy.stats.health -
|
||||||
|
nearest_enemy.stats.monster_info['health'])/nearest_enemy.stats.monster_info['health'])) - 1,
|
||||||
|
|
||||||
|
1/(np.exp((len(self.distance_direction_from_enemy) -
|
||||||
|
self.max_num_enemies)/self.max_num_enemies)) - 1,
|
||||||
|
|
||||||
|
1 - 1/(np.exp((self.stats.health -
|
||||||
|
self.stats.stats['health'])/self.stats.stats['health']))
|
||||||
if not self.is_dead() > 0 else -1
|
if not self.is_dead() > 0 else -1
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -233,7 +242,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
self.stats.energy_recovery()
|
self.stats.energy_recovery()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.stats.exp = max(0, self.stats.exp - .01)
|
self.stats.exp = max(-1, self.stats.exp - .1)
|
||||||
|
|
||||||
# Refresh player based on input and animate
|
# Refresh player based on input and animate
|
||||||
self.get_status()
|
self.get_status()
|
||||||
|
|
12
pneuma.py
12
pneuma.py
|
@ -1,4 +1,4 @@
|
||||||
import os
|
import random
|
||||||
import argparse
|
import argparse
|
||||||
import torch as T
|
import torch as T
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
parser.add_argument('--horizon',
|
parser.add_argument('--horizon',
|
||||||
type=int,
|
type=int,
|
||||||
default=200,
|
default=2048,
|
||||||
help="The number of steps per update")
|
help="The number of steps per update")
|
||||||
|
|
||||||
parser.add_argument('--show_pg',
|
parser.add_argument('--show_pg',
|
||||||
|
@ -100,6 +100,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
T.manual_seed(args.seed)
|
T.manual_seed(args.seed)
|
||||||
|
|
||||||
|
@ -219,12 +220,6 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print(f"Models saved to {chkpt_path}")
|
print(f"Models saved to {chkpt_path}")
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"\nScore this round for player {player.player_id}:\
|
|
||||||
{player.stats.exp}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# End of training session
|
# End of training session
|
||||||
print("End of episodes.\
|
print("End of episodes.\
|
||||||
\nExiting game...")
|
\nExiting game...")
|
||||||
|
@ -264,4 +259,5 @@ if __name__ == "__main__":
|
||||||
for total in total_loss:
|
for total in total_loss:
|
||||||
plt.plot(total)
|
plt.plot(total)
|
||||||
plt.savefig(f"{figure_folder}/total_loss.png")
|
plt.savefig(f"{figure_folder}/total_loss.png")
|
||||||
|
|
||||||
game.quit()
|
game.quit()
|
||||||
|
|
Loading…
Reference in a new issue