Added more rewards

This commit is contained in:
Vasilis Valatsos 2023-11-19 04:27:47 +01:00
parent 115b2e4151
commit da649ccca8
15 changed files with 135 additions and 165 deletions

0
__init__.py Normal file
View file

View file

@ -1,3 +1,4 @@
import os
import numpy as np
import torch as T
@ -48,6 +49,7 @@ class Agent:
probs = T.squeeze(dist.log_prob(action)).item()
action = T.squeeze(action).item()
value = T.squeeze(value).item()
return action, probs, value
def learn(self):
@ -95,7 +97,6 @@ class Agent:
critic_loss = critic_loss.mean()
total_loss = actor_loss + 0.5*critic_loss
self.actor.optimizer.zero_grad()
self.critic.optimizer.zero_grad()
total_loss.backward()

View file

@ -66,7 +66,7 @@ class ActorNetwork(nn.Module):
nn.Softmax(dim=-1)
)
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else (
'mps' if T.backends.mps.is_available() else 'cpu'))
@ -101,7 +101,8 @@ class CriticNetwork(nn.Module):
nn.Linear(fc2_dims, 1)
)
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
self.device = T.device('cuda:0' if T.cuda.is_available() else (
'mps' if T.backends.mps.is_available() else 'cpu'))
@ -109,7 +110,6 @@ class CriticNetwork(nn.Module):
def forward(self, state):
    """Evaluate the critic on ``state``.

    Args:
        state: Input passed unchanged to ``self.critic``.

    Returns:
        The result of ``self.critic(state)``.
    """
    return self.critic(state)
def save_checkpoint(self):

Binary file not shown.

Binary file not shown.

View file

@ -18,6 +18,8 @@ class MagicPlayer:
'flame': pygame.mixer.Sound(f'{asset_path}/audio/flame.wav')
}
self.sounds['flame'].set_volume(0)
def heal(self, player, strength, cost, groups):
if player.stats.energy >= cost:
self.sounds['heal'].play()

View file

@ -10,7 +10,6 @@ from .combat import CombatHandler
class InputHandler:
# , status):
def __init__(self, sprite_type, animation_player, ai_controller=False):
self.status = 'down'
self.sprite_type = sprite_type

View file

@ -22,6 +22,6 @@ class AudioHandler:
self.death_sound = pygame.mixer.Sound(
f'{asset_path}/death.wav')
self.hit_sound = pygame.mixer.Sound(f'{asset_path}/hit.wav')
self.death_sound.set_volume(0.2)
self.hit_sound.set_volume(0.2)
self.attack_sound.set_volume(0.2)
self.death_sound.set_volume(0)
self.hit_sound.set_volume(0)
self.attack_sound.set_volume(0)

View file

@ -37,7 +37,7 @@ class CombatHandler:
self.weapon_attack_sound = pygame.mixer.Sound(
f"{asset_path}/sword.wav")
self.weapon_attack_sound.set_volume(0.2)
self.weapon_attack_sound.set_volume(0)
def create_attack_sprite(self, player):
self.current_attack = Weapon(

View file

@ -45,6 +45,12 @@ class Player(pygame.sprite.Sprite):
self.distance_direction_from_enemy = None
# Setup AI
self.score = 0
self.learn_iters = 0
self.n_steps = 0
self.N = 20
def get_status(self):
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
if 'idle' not in self.status and 'attack' not in self.status:
@ -92,34 +98,40 @@ class Player(pygame.sprite.Sprite):
def get_current_state(self):
if self.distance_direction_from_enemy != []:
sorted_distances = sorted(
self.distance_direction_from_enemy, key=lambda x: x[0])
else:
sorted_distances = np.zeros(self.num_features)
nearest_dist, _, nearest_enemy = sorted_distances[0]
self.action_features = [self._input.action]
self.reward_features = [self.stats.exp]
self.reward_features = [self.stats.exp,
np.exp(-(nearest_dist)),
np.exp(-(nearest_enemy.stats.health)),
- np.exp(self.stats.health)
]
self.state_features = [
self.stats.role_id,
self.rect.center[0],
self.rect.center[1],
self.stats.health,
self.stats.energy,
self.stats.attack,
self.stats.magic,
self.stats.speed,
int(self._input.combat.vulnerable),
int(self._input.can_move),
int(self._input.attacking),
int(self._input.can_rotate_weapon),
int(self._input.can_swap_magic)
self.stats.speed
]
enemy_states = []
for distance, direction, enemy in self.distance_direction_from_enemy:
for distance, direction, enemy in sorted_distances[:5]:
enemy_states.extend([
distance,
direction[0],
direction[1],
enemy.stats.monster_id,
0 if enemy.animation.status == "idle" else (
1 if enemy.animation.status == "move" else 2),
enemy.stats.health,
enemy.stats.attack,
enemy.stats.speed,
@ -127,22 +139,37 @@ class Player(pygame.sprite.Sprite):
enemy.stats.attack_radius,
enemy.stats.notice_radius
])
self.state_features.extend(enemy_states)
def setup_agent(self):
# Setup AI
self.get_current_state()
self.agent = Agent(
input_dims=len(self.state_features), n_actions=len(self._input.possible_actions), batch_size=5, n_epochs=4)
self.score = 0
self.learn_iters = 0
if hasattr(self, 'num_features'):
while len(self.state_features) < self.num_features:
self.state_features.append(0)
self.n_steps = 0
self.N = 20
self.state_features = np.array(self.state_features)
min_feat = np.min(self.state_features)
max_feat = np.max(self.state_features)
self.state_features = (self.state_features -
min_feat) / (max_feat-min_feat)
def get_max_num_states(self):
    """Store the length of the freshly built state vector.

    Calls ``get_current_state()`` to rebuild ``self.state_features`` and
    then saves its length on ``self.num_features``.
    """
    self.get_current_state()
    feature_count = len(self.state_features)
    self.num_features = feature_count
def setup_agent(self):
    """Create this player's ``Agent`` and try to restore saved models.

    The agent's input size is the current length of ``self.state_features``
    and its action count is the number of entries in
    ``self._input.possible_actions``. A ``FileNotFoundError`` raised by
    ``load_models()`` is printed and otherwise ignored.
    """
    state_size = len(self.state_features)
    action_count = len(self._input.possible_actions)
    self.agent = Agent(input_dims=state_size,
                       n_actions=action_count,
                       batch_size=5,
                       n_epochs=4)

    try:
        self.agent.load_models()
    except FileNotFoundError as err:
        print(f"{err}. Skipping loading...")
def is_dead(self):
    """Report whether the player has run out of health.

    When the player is dead, ``self.stats.exp`` is overwritten with -100.
    ``stats.exp`` is one of the values placed in ``reward_features``, so
    this acts as the death penalty.

    Returns:
        bool: ``True`` if ``self.stats.health`` is zero or negative,
        ``False`` otherwise.
    """
    # Use <= rather than ==: damage that takes health below zero must
    # still count as death.
    if self.stats.health <= 0:
        self.stats.exp = -100
        return True
    return False
@ -157,8 +184,12 @@ class Player(pygame.sprite.Sprite):
self.n_steps += 1
# Apply chosen action
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
self.obstacle_sprites, self.animation.rect, self)
self._input.check_input(action,
self.stats.speed,
self.animation.hitbox,
self.obstacle_sprites,
self.animation.rect,
self)
self.done = self.is_dead()

41
game.py Normal file
View file

@ -0,0 +1,41 @@
import pygame
import sys
from level.level import Level
from configs.system.window_config import WIDTH, HEIGHT, WATER_COLOR, FPS
class Game:
    """Owns the pygame window, the frame clock and the ``Level`` in play."""

    def __init__(self):
        """Initialise pygame, open the window, build the level, start music."""
        pygame.init()

        window_size = (WIDTH, HEIGHT)
        self.screen = pygame.display.set_mode(window_size)
        pygame.display.set_caption('Pneuma')
        self.clock = pygame.time.Clock()

        self.level = Level()

        # The background track loops indefinitely but its volume is set to 0.
        soundtrack = pygame.mixer.Sound('assets/audio/main.ogg')
        soundtrack.set_volume(0)
        soundtrack.play(loops=-1)

    def run(self):
        """Process pending events and render a single frame.

        A QUIT event shuts pygame down and exits the interpreter; pressing
        ``M`` toggles the level menu. Afterwards the screen is cleared, the
        level is run as ``'observer'``, the display is updated and the clock
        is ticked at ``FPS``.
        """
        for evt in pygame.event.get():
            if evt.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
            elif evt.type == pygame.KEYDOWN and evt.key == pygame.K_m:
                self.level.toggle_menu()

        self.screen.fill(WATER_COLOR)
        self.level.run(who='observer')
        pygame.display.update()
        self.clock.tick(FPS)

View file

@ -20,7 +20,7 @@ from .camera import Camera
class Level:
def __init__(self):
def __init__(self, reset=False):
# General Settings
self.game_paused = False
@ -39,8 +39,13 @@ class Level:
self.create_map()
self.get_players_enemies()
self.get_distance_direction()
for player in self.player_sprites:
player.setup_agent()
if not reset:
for player in self.player_sprites:
player.get_max_num_states()
player.setup_agent()
else:
for player in self.player_sprites:
player.get_max_num_states()
# UI setup
self.ui = UI()
@ -75,84 +80,11 @@ class Level:
random_grass_image = choice(graphics['grass'])
Tile((x, y), [self.visible_sprites, self.obstacle_sprites,
self.attackable_sprites], 'grass', random_grass_image)
if style == 'objects':
surf = graphics['objects'][int(col)]
Tile((x, y), [self.visible_sprites,
self.obstacle_sprites], 'object', surf)
if style == 'entities':
# The numbers represent their IDs in .csv files generated from TILED.
if col == '500':
self.observer = Observer(
(x, y), [self.visible_sprites])
elif col == '400':
# Player Generation
Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id)
player_id += 1
elif col == '401':
# Player Generation
Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id)
player_id += 1
elif col == '402':
# Player Generation
Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id)
player_id += 1
else:
# Monster Generation
if col == '390':
monster_name = 'bamboo'
elif col == '391':
monster_name = 'spirit'
elif col == '392':
monster_name = 'raccoon'
else:
monster_name = 'squid'
Enemy(monster_name, (x, y), [
self.visible_sprites, self.attackable_sprites], self.visible_sprites, self.obstacle_sprites)
def reset_map(self):
script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join(
script_dir, '..', 'assets')
layouts = {
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
'objects': import_csv_layout(f"{asset_path}/map/Objects.csv"),
'entities': import_csv_layout(f"{asset_path}/map/Entities.csv")
}
graphics = {
'grass': import_folder(f"{asset_path}/graphics/grass"),
'objects': import_folder(f"{asset_path}/graphics/objects")
}
for style, layout in layouts.items():
for row_index, row in enumerate(layout):
for col_index, col in enumerate(row):
if col != '-1':
x = col_index * TILESIZE
y = row_index * TILESIZE
if style == 'boundary':
Tile((x, y), [self.obstacle_sprites], 'invisible')
if style == 'grass':
random_grass_image = choice(graphics['grass'])
Tile((x, y), [self.visible_sprites, self.obstacle_sprites,
self.attackable_sprites], 'grass', random_grass_image)
if style == 'objects':
surf = graphics['objects'][int(col)]
Tile((x, y), [self.visible_sprites,
self.obstacle_sprites], 'object', surf)
#
# if style == 'objects':
# surf = graphics['objects'][int(col)]
# Tile((x, y), [self.visible_sprites,
# self.obstacle_sprites], 'object', surf)
if style == 'entities':
# The numbers represent their IDs in .csv files generated from TILED.

66
main.py
View file

@ -1,51 +1,12 @@
import sys
import numpy as np
import torch
import pygame
from game import Game
from tqdm import tqdm
from configs.system.window_config import WIDTH, HEIGHT, WATER_COLOR, FPS
from level.level import Level
class Game:
def __init__(self):
pygame.init()
self.screen = pygame.display.set_mode(
(WIDTH, HEIGHT))
pygame.display.set_caption('Pneuma')
self.clock = pygame.time.Clock()
self.level = Level()
# Sound
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
main_sound.set_volume(0.4)
main_sound.play(loops=-1)
def run(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit()
if event.type == pygame.KEYDOWN:
if event.key == pygame.K_m:
self.level.toggle_menu()
self.screen.fill(WATER_COLOR)
self.level.run(who='observer')
pygame.display.update()
self.clock.tick(FPS)
from os import environ
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
if __name__ == '__main__':
n_games = 300
n_episodes = 1000
figure_file = 'plots/score.png'
score_history = []
@ -54,22 +15,24 @@ if __name__ == '__main__':
agent_list = []
game_len = 10000
game_len = 5000
game = Game()
for i in tqdm(range(n_games)):
for i in tqdm(range(n_episodes)):
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
game.level.__init__()
if i != 0:
game.level.__init__(reset=True)
# TODO: Make game.level.reset_map() so we don't pull out and load the agent every time (There is -definitely- a better way)
for player in game.level.player_sprites:
for player_id, agent in agent_list:
if player.player_id == player_id:
player.agent = agent
agent_list = []
done = False
score = 0
for _ in range(game_len):
for _ in tqdm(range(game_len)):
if not game.level.done:
game.run()
else:
@ -77,13 +40,14 @@ if __name__ == '__main__':
for player in game.level.player_sprites:
agent_list.append((player.player_id, player.agent))
if i == n_games-1 and game.level.enemy_sprites != []:
if i == n_episodes-1 and game.level.enemy_sprites != []:
for player in game.level.player_sprites:
player.stats.exp -= 5
for enemy in game.level.enemy_sprites:
player.stats.exp -= 5
player.update()
for player in game.level.player_sprites:
player.agent.save_models()
for player in game.level.player_sprites:
player.agent.save_models()
# TODO: Make it so that scores appear here for each player
# score_history.append(game.level.player.score)

Binary file not shown.

Binary file not shown.