Added more rewards
This commit is contained in:
parent
115b2e4151
commit
da649ccca8
15 changed files with 135 additions and 165 deletions
0
__init__.py
Normal file
0
__init__.py
Normal file
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import torch as T
|
||||
|
||||
|
@ -48,6 +49,7 @@ class Agent:
|
|||
probs = T.squeeze(dist.log_prob(action)).item()
|
||||
action = T.squeeze(action).item()
|
||||
value = T.squeeze(value).item()
|
||||
|
||||
return action, probs, value
|
||||
|
||||
def learn(self):
|
||||
|
@ -95,7 +97,6 @@ class Agent:
|
|||
critic_loss = critic_loss.mean()
|
||||
|
||||
total_loss = actor_loss + 0.5*critic_loss
|
||||
|
||||
self.actor.optimizer.zero_grad()
|
||||
self.critic.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
|
|
|
@ -66,7 +66,7 @@ class ActorNetwork(nn.Module):
|
|||
nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
|
@ -101,7 +101,8 @@ class CriticNetwork(nn.Module):
|
|||
nn.Linear(fc2_dims, 1)
|
||||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
|
||||
|
@ -109,7 +110,6 @@ class CriticNetwork(nn.Module):
|
|||
|
||||
def forward(self, state):
|
||||
value = self.critic(state)
|
||||
|
||||
return value
|
||||
|
||||
def save_checkpoint(self):
|
||||
|
|
BIN
agents/ppo/model/attempt_1/actor_torch_ppo
Normal file
BIN
agents/ppo/model/attempt_1/actor_torch_ppo
Normal file
Binary file not shown.
BIN
agents/ppo/model/attempt_1/critic_torch_ppo
Normal file
BIN
agents/ppo/model/attempt_1/critic_torch_ppo
Normal file
Binary file not shown.
|
@ -18,6 +18,8 @@ class MagicPlayer:
|
|||
'flame': pygame.mixer.Sound(f'{asset_path}/audio/flame.wav')
|
||||
}
|
||||
|
||||
self.sounds['flame'].set_volume(0)
|
||||
|
||||
def heal(self, player, strength, cost, groups):
|
||||
if player.stats.energy >= cost:
|
||||
self.sounds['heal'].play()
|
||||
|
|
|
@ -10,7 +10,6 @@ from .combat import CombatHandler
|
|||
|
||||
class InputHandler:
|
||||
|
||||
# , status):
|
||||
def __init__(self, sprite_type, animation_player, ai_controller=False):
|
||||
self.status = 'down'
|
||||
self.sprite_type = sprite_type
|
||||
|
|
|
@ -22,6 +22,6 @@ class AudioHandler:
|
|||
self.death_sound = pygame.mixer.Sound(
|
||||
f'{asset_path}/death.wav')
|
||||
self.hit_sound = pygame.mixer.Sound(f'{asset_path}/hit.wav')
|
||||
self.death_sound.set_volume(0.2)
|
||||
self.hit_sound.set_volume(0.2)
|
||||
self.attack_sound.set_volume(0.2)
|
||||
self.death_sound.set_volume(0)
|
||||
self.hit_sound.set_volume(0)
|
||||
self.attack_sound.set_volume(0)
|
||||
|
|
|
@ -37,7 +37,7 @@ class CombatHandler:
|
|||
|
||||
self.weapon_attack_sound = pygame.mixer.Sound(
|
||||
f"{asset_path}/sword.wav")
|
||||
self.weapon_attack_sound.set_volume(0.2)
|
||||
self.weapon_attack_sound.set_volume(0)
|
||||
|
||||
def create_attack_sprite(self, player):
|
||||
self.current_attack = Weapon(
|
||||
|
|
|
@ -45,6 +45,12 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.distance_direction_from_enemy = None
|
||||
|
||||
# Setup AI
|
||||
self.score = 0
|
||||
self.learn_iters = 0
|
||||
self.n_steps = 0
|
||||
self.N = 20
|
||||
|
||||
def get_status(self):
|
||||
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
|
||||
if 'idle' not in self.status and 'attack' not in self.status:
|
||||
|
@ -92,34 +98,40 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
def get_current_state(self):
|
||||
|
||||
if self.distance_direction_from_enemy != []:
|
||||
sorted_distances = sorted(
|
||||
self.distance_direction_from_enemy, key=lambda x: x[0])
|
||||
else:
|
||||
sorted_distances = np.zeros(self.num_features)
|
||||
|
||||
nearest_dist, _, nearest_enemy = sorted_distances[0]
|
||||
|
||||
self.action_features = [self._input.action]
|
||||
self.reward_features = [self.stats.exp]
|
||||
|
||||
self.reward_features = [self.stats.exp,
|
||||
np.exp(-(nearest_dist)),
|
||||
np.exp(-(nearest_enemy.stats.health)),
|
||||
- np.exp(self.stats.health)
|
||||
]
|
||||
|
||||
self.state_features = [
|
||||
self.stats.role_id,
|
||||
self.rect.center[0],
|
||||
self.rect.center[1],
|
||||
self.stats.health,
|
||||
self.stats.energy,
|
||||
self.stats.attack,
|
||||
self.stats.magic,
|
||||
self.stats.speed,
|
||||
int(self._input.combat.vulnerable),
|
||||
int(self._input.can_move),
|
||||
int(self._input.attacking),
|
||||
int(self._input.can_rotate_weapon),
|
||||
int(self._input.can_swap_magic)
|
||||
self.stats.speed
|
||||
]
|
||||
|
||||
enemy_states = []
|
||||
|
||||
for distance, direction, enemy in self.distance_direction_from_enemy:
|
||||
for distance, direction, enemy in sorted_distances[:5]:
|
||||
|
||||
enemy_states.extend([
|
||||
distance,
|
||||
direction[0],
|
||||
direction[1],
|
||||
enemy.stats.monster_id,
|
||||
0 if enemy.animation.status == "idle" else (
|
||||
1 if enemy.animation.status == "move" else 2),
|
||||
enemy.stats.health,
|
||||
enemy.stats.attack,
|
||||
enemy.stats.speed,
|
||||
|
@ -127,22 +139,37 @@ class Player(pygame.sprite.Sprite):
|
|||
enemy.stats.attack_radius,
|
||||
enemy.stats.notice_radius
|
||||
])
|
||||
|
||||
self.state_features.extend(enemy_states)
|
||||
|
||||
def setup_agent(self):
|
||||
# Setup AI
|
||||
self.get_current_state()
|
||||
self.agent = Agent(
|
||||
input_dims=len(self.state_features), n_actions=len(self._input.possible_actions), batch_size=5, n_epochs=4)
|
||||
self.score = 0
|
||||
self.learn_iters = 0
|
||||
if hasattr(self, 'num_features'):
|
||||
while len(self.state_features) < self.num_features:
|
||||
self.state_features.append(0)
|
||||
|
||||
self.n_steps = 0
|
||||
self.N = 20
|
||||
self.state_features = np.array(self.state_features)
|
||||
min_feat = np.min(self.state_features)
|
||||
max_feat = np.max(self.state_features)
|
||||
self.state_features = (self.state_features -
|
||||
min_feat) / (max_feat-min_feat)
|
||||
|
||||
def get_max_num_states(self):
|
||||
self.get_current_state()
|
||||
self.num_features = len(self.state_features)
|
||||
|
||||
def setup_agent(self):
|
||||
self.agent = Agent(
|
||||
input_dims=len(self.state_features),
|
||||
n_actions=len(self._input.possible_actions),
|
||||
batch_size=5,
|
||||
n_epochs=4)
|
||||
try:
|
||||
self.agent.load_models()
|
||||
except FileNotFoundError as e:
|
||||
print(f"{e}. Skipping loading...")
|
||||
|
||||
def is_dead(self):
|
||||
if self.stats.health == 0:
|
||||
self.stats.exp = -10
|
||||
self.stats.exp = -100
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -157,8 +184,12 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.n_steps += 1
|
||||
# Apply chosen action
|
||||
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
|
||||
self.obstacle_sprites, self.animation.rect, self)
|
||||
self._input.check_input(action,
|
||||
self.stats.speed,
|
||||
self.animation.hitbox,
|
||||
self.obstacle_sprites,
|
||||
self.animation.rect,
|
||||
self)
|
||||
|
||||
self.done = self.is_dead()
|
||||
|
||||
|
|
41
game.py
Normal file
41
game.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import pygame
|
||||
import sys
|
||||
|
||||
from level.level import Level
|
||||
from configs.system.window_config import WIDTH, HEIGHT, WATER_COLOR, FPS
|
||||
|
||||
|
||||
class Game:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
pygame.init()
|
||||
|
||||
self.screen = pygame.display.set_mode(
|
||||
(WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Pneuma')
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
self.level = Level()
|
||||
|
||||
# Sound
|
||||
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
|
||||
main_sound.set_volume(0)
|
||||
main_sound.play(loops=-1)
|
||||
|
||||
def run(self):
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
sys.exit()
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_m:
|
||||
self.level.toggle_menu()
|
||||
|
||||
self.screen.fill(WATER_COLOR)
|
||||
|
||||
self.level.run(who='observer')
|
||||
|
||||
pygame.display.update()
|
||||
self.clock.tick(FPS)
|
|
@ -20,7 +20,7 @@ from .camera import Camera
|
|||
|
||||
class Level:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, reset=False):
|
||||
|
||||
# General Settings
|
||||
self.game_paused = False
|
||||
|
@ -39,8 +39,13 @@ class Level:
|
|||
self.create_map()
|
||||
self.get_players_enemies()
|
||||
self.get_distance_direction()
|
||||
if not reset:
|
||||
for player in self.player_sprites:
|
||||
player.get_max_num_states()
|
||||
player.setup_agent()
|
||||
else:
|
||||
for player in self.player_sprites:
|
||||
player.get_max_num_states()
|
||||
|
||||
# UI setup
|
||||
self.ui = UI()
|
||||
|
@ -75,84 +80,11 @@ class Level:
|
|||
random_grass_image = choice(graphics['grass'])
|
||||
Tile((x, y), [self.visible_sprites, self.obstacle_sprites,
|
||||
self.attackable_sprites], 'grass', random_grass_image)
|
||||
|
||||
if style == 'objects':
|
||||
surf = graphics['objects'][int(col)]
|
||||
Tile((x, y), [self.visible_sprites,
|
||||
self.obstacle_sprites], 'object', surf)
|
||||
|
||||
if style == 'entities':
|
||||
# The numbers represent their IDs in .csv files generated from TILED.
|
||||
if col == '500':
|
||||
self.observer = Observer(
|
||||
(x, y), [self.visible_sprites])
|
||||
|
||||
elif col == '400':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '401':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '402':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id)
|
||||
player_id += 1
|
||||
|
||||
else:
|
||||
# Monster Generation
|
||||
if col == '390':
|
||||
monster_name = 'bamboo'
|
||||
elif col == '391':
|
||||
monster_name = 'spirit'
|
||||
elif col == '392':
|
||||
monster_name = 'raccoon'
|
||||
else:
|
||||
monster_name = 'squid'
|
||||
|
||||
Enemy(monster_name, (x, y), [
|
||||
self.visible_sprites, self.attackable_sprites], self.visible_sprites, self.obstacle_sprites)
|
||||
|
||||
def reset_map(self):
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '..', 'assets')
|
||||
layouts = {
|
||||
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
|
||||
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
|
||||
'objects': import_csv_layout(f"{asset_path}/map/Objects.csv"),
|
||||
'entities': import_csv_layout(f"{asset_path}/map/Entities.csv")
|
||||
}
|
||||
|
||||
graphics = {
|
||||
'grass': import_folder(f"{asset_path}/graphics/grass"),
|
||||
'objects': import_folder(f"{asset_path}/graphics/objects")
|
||||
}
|
||||
|
||||
for style, layout in layouts.items():
|
||||
for row_index, row in enumerate(layout):
|
||||
for col_index, col in enumerate(row):
|
||||
if col != '-1':
|
||||
x = col_index * TILESIZE
|
||||
y = row_index * TILESIZE
|
||||
if style == 'boundary':
|
||||
Tile((x, y), [self.obstacle_sprites], 'invisible')
|
||||
|
||||
if style == 'grass':
|
||||
random_grass_image = choice(graphics['grass'])
|
||||
Tile((x, y), [self.visible_sprites, self.obstacle_sprites,
|
||||
self.attackable_sprites], 'grass', random_grass_image)
|
||||
|
||||
if style == 'objects':
|
||||
surf = graphics['objects'][int(col)]
|
||||
Tile((x, y), [self.visible_sprites,
|
||||
self.obstacle_sprites], 'object', surf)
|
||||
#
|
||||
# if style == 'objects':
|
||||
# surf = graphics['objects'][int(col)]
|
||||
# Tile((x, y), [self.visible_sprites,
|
||||
# self.obstacle_sprites], 'object', surf)
|
||||
|
||||
if style == 'entities':
|
||||
# The numbers represent their IDs in .csv files generated from TILED.
|
||||
|
|
60
main.py
60
main.py
|
@ -1,51 +1,12 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import pygame
|
||||
from game import Game
|
||||
from tqdm import tqdm
|
||||
from configs.system.window_config import WIDTH, HEIGHT, WATER_COLOR, FPS
|
||||
|
||||
from level.level import Level
|
||||
|
||||
|
||||
class Game:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
pygame.init()
|
||||
|
||||
self.screen = pygame.display.set_mode(
|
||||
(WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Pneuma')
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
self.level = Level()
|
||||
|
||||
# Sound
|
||||
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
|
||||
main_sound.set_volume(0.4)
|
||||
main_sound.play(loops=-1)
|
||||
|
||||
def run(self):
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
sys.exit()
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_m:
|
||||
self.level.toggle_menu()
|
||||
|
||||
self.screen.fill(WATER_COLOR)
|
||||
|
||||
self.level.run(who='observer')
|
||||
|
||||
pygame.display.update()
|
||||
self.clock.tick(FPS)
|
||||
from os import environ
|
||||
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
n_games = 300
|
||||
n_episodes = 1000
|
||||
|
||||
figure_file = 'plots/score.png'
|
||||
score_history = []
|
||||
|
@ -54,22 +15,24 @@ if __name__ == '__main__':
|
|||
|
||||
agent_list = []
|
||||
|
||||
game_len = 10000
|
||||
game_len = 5000
|
||||
|
||||
game = Game()
|
||||
|
||||
for i in tqdm(range(n_games)):
|
||||
for i in tqdm(range(n_episodes)):
|
||||
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
||||
game.level.__init__()
|
||||
if i != 0:
|
||||
game.level.__init__(reset=True)
|
||||
# TODO: Make game.level.reset_map() so we don't pull out and load the agent every time (There is -definitevly- a better way)
|
||||
for player in game.level.player_sprites:
|
||||
for player_id, agent in agent_list:
|
||||
if player.player_id == player_id:
|
||||
player.agent = agent
|
||||
|
||||
agent_list = []
|
||||
done = False
|
||||
score = 0
|
||||
for _ in range(game_len):
|
||||
for _ in tqdm(range(game_len)):
|
||||
if not game.level.done:
|
||||
game.run()
|
||||
else:
|
||||
|
@ -77,8 +40,9 @@ if __name__ == '__main__':
|
|||
for player in game.level.player_sprites:
|
||||
agent_list.append((player.player_id, player.agent))
|
||||
|
||||
if i == n_games-1 and game.level.enemy_sprites != []:
|
||||
if i == n_episodes-1 and game.level.enemy_sprites != []:
|
||||
for player in game.level.player_sprites:
|
||||
for enemy in game.level.enemy_sprites:
|
||||
player.stats.exp -= 5
|
||||
player.update()
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue