Implemented (badly) agent
This commit is contained in:
parent
d20f46bb9d
commit
b4a6e99fce
38 changed files with 200 additions and 158 deletions
|
@ -1,14 +1,12 @@
|
||||||
import random
|
import numpy as np
|
||||||
import torch
|
import torch as T
|
||||||
|
|
||||||
from numpy.random import default_rng
|
from .brain import ActorNetwork, CriticNetwork, PPOMemory
|
||||||
|
|
||||||
from rl.brain import ActorNetwork, CriticNetwork, PPOMemory
|
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
|
|
||||||
def __init__(self, n_actions, input_dims, gamma = 0.99, alpha = 0.0003, policy_clip = 0.2, batch_size = 64, N=2048, n_epochs = 10, gae_lambda = 0.95):
|
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, policy_clip=0.2, batch_size=64, N=2048, n_epochs=10, gae_lambda=0.95):
|
||||||
|
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.policy_clip = policy_clip
|
self.policy_clip = policy_clip
|
||||||
|
@ -33,13 +31,13 @@ class Agent:
|
||||||
print('... done ...')
|
print('... done ...')
|
||||||
|
|
||||||
def load_models(self):
|
def load_models(self):
|
||||||
print('... loadng models ...')
|
print('... loading models ...')
|
||||||
self.actor.load_checkpoint()
|
self.actor.load_checkpoint()
|
||||||
self.critic.load_checkpoint()
|
self.critic.load_checkpoint()
|
||||||
print('.. done ...')
|
print('.. done ...')
|
||||||
|
|
||||||
def choose_action(self, observation):
|
def choose_action(self, observation):
|
||||||
state = T.tensor([observation], dtype = T.float).to(self.actor.device)
|
state = observation.to(self.actor.device, dtype=T.float)
|
||||||
|
|
||||||
dist = self.actor(state)
|
dist = self.actor(state)
|
||||||
value = self.critic(state)
|
value = self.critic(state)
|
||||||
|
@ -62,15 +60,19 @@ class Agent:
|
||||||
discount = 1
|
discount = 1
|
||||||
a_t = 0
|
a_t = 0
|
||||||
for k in range(t, len(reward_arr)-1):
|
for k in range(t, len(reward_arr)-1):
|
||||||
a_t += discount*(reward_arr[k] + self.gamma*values[k+1]*(1-int(dones_arr[k])) - values[k])
|
a_t += discount * \
|
||||||
|
(reward_arr[k] + self.gamma*values[k+1]
|
||||||
|
* (1-int(done_arr[k])) - values[k])
|
||||||
discount *= self.gamma * self.gae_lambda
|
discount *= self.gamma * self.gae_lambda
|
||||||
advantage[t] = a_t
|
advantage[t] = a_t
|
||||||
advantage = T.tensor(Advantage).to(self.actor.device)
|
advantage = T.tensor(advantage).to(self.actor.device)
|
||||||
|
|
||||||
values = T.tensor(values).to(self.actor.device)
|
values = T.tensor(values).to(self.actor.device)
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
states = T.tensor(state_arr[batch], dtype = T.float).to(self.actor.device)
|
states = T.tensor(state_arr[batch], dtype=T.float).to(
|
||||||
old_probs = T.tensor(old_probs_arr[batch]).to(self.actor.device)
|
self.actor.device)
|
||||||
|
old_probs = T.tensor(old_probs_arr[batch]).to(
|
||||||
|
self.actor.device)
|
||||||
actions = T.tensor(action_arr[batch]).to(self.actor.device)
|
actions = T.tensor(action_arr[batch]).to(self.actor.device)
|
||||||
|
|
||||||
dist = self.actor(states)
|
dist = self.actor(states)
|
||||||
|
@ -81,8 +83,10 @@ class Agent:
|
||||||
new_probs = dist.log_prob(actions)
|
new_probs = dist.log_prob(actions)
|
||||||
prob_ratio = new_probs.exp() / old_probs.exp()
|
prob_ratio = new_probs.exp() / old_probs.exp()
|
||||||
weighted_probs = advantage[batch] * prob_ratio
|
weighted_probs = advantage[batch] * prob_ratio
|
||||||
weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
weighted_clipped_probs = T.clamp(
|
||||||
actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()
|
prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
||||||
|
actor_loss = -T.min(weighted_probs,
|
||||||
|
weighted_clipped_probs).mean()
|
||||||
|
|
||||||
returns = advantage[batch] + values[batch]
|
returns = advantage[batch] + values[batch]
|
||||||
critic_loss = (returns - critic_value)**2
|
critic_loss = (returns - critic_value)**2
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.distributions.categorical import Categorical
|
from torch.distributions.categorical import Categorical
|
||||||
|
|
||||||
|
|
||||||
class PPOMemory:
|
class PPOMemory:
|
||||||
def __init__(self, batch_size):
|
def __init__(self, batch_size):
|
||||||
self.states = []
|
self.states = []
|
||||||
|
@ -47,6 +48,7 @@ class PPOMemory:
|
||||||
self.rewards = []
|
self.rewards = []
|
||||||
self.dones = []
|
self.dones = []
|
||||||
|
|
||||||
|
|
||||||
class ActorNetwork(nn.Module):
|
class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||||
|
@ -54,17 +56,18 @@ class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
|
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
|
||||||
self.actor = nn.Sequential(
|
self.actor = nn.Sequential(
|
||||||
nn.Linear(len(input_dim), fc1_dims),
|
nn.Linear(input_dim, fc1_dims),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(fc2_dims, len(output_dim)),
|
nn.Linear(fc2_dims, output_dim),
|
||||||
nn.Softmax(dim=-1)
|
nn.Softmax(dim=-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||||
|
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu'))
|
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||||
|
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
|
@ -80,6 +83,7 @@ class ActorNetwork(nn.Module):
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self):
|
||||||
self.load_state_dict(T.load(self.checkpoint_file))
|
self.load_state_dict(T.load(self.checkpoint_file))
|
||||||
|
|
||||||
|
|
||||||
class CriticNetwork(nn.Module):
|
class CriticNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||||
|
@ -87,7 +91,7 @@ class CriticNetwork(nn.Module):
|
||||||
|
|
||||||
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
|
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
|
||||||
self.critic = nn.Sequential(
|
self.critic = nn.Sequential(
|
||||||
nn.Linear(len(input_dims), fc1_dims),
|
nn.Linear(input_dims, fc1_dims),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
@ -95,12 +99,13 @@ class CriticNetwork(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu'))
|
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||||
|
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
vale = self.critic(state)
|
value = self.critic(state)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -109,19 +114,3 @@ class CriticNetwork(nn.Module):
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self):
|
||||||
self.load_state_dict(T.load(self.checkpoint_file))
|
self.load_state_dict(T.load(self.checkpoint_file))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets')
|
script_dir, '../..', 'assets')
|
||||||
|
|
||||||
monster_data = {
|
monster_data = {
|
||||||
'squid': {'id': 1, 'health': 100, 'exp': 100, 'attack': 20, 'attack_type': 'slash', 'attack_sound': f'{asset_path}/audio/attack/slash.wav', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
|
'squid': {'id': 1, 'health': 100, 'exp': 100, 'attack': 20, 'attack_type': 'slash', 'attack_sound': f'{asset_path}/audio/attack/slash.wav', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
|
|
@ -2,7 +2,7 @@ import os
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets')
|
script_dir, '../..', 'assets')
|
||||||
|
|
||||||
magic_data = {
|
magic_data = {
|
||||||
'flame': {'strength': 5, 'cost': 20, 'graphic': f"{asset_path}/graphics/particles/flame/fire.png"},
|
'flame': {'strength': 5, 'cost': 20, 'graphic': f"{asset_path}/graphics/particles/flame/fire.png"},
|
|
@ -2,7 +2,7 @@ import os
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets')
|
script_dir, '../..', 'assets')
|
||||||
|
|
||||||
weapon_data = {
|
weapon_data = {
|
||||||
'sword': {'cooldown': 100, 'damage': 15, 'graphic': f"{asset_path}/graphics/weapons/sword/full.png"},
|
'sword': {'cooldown': 100, 'damage': 15, 'graphic': f"{asset_path}/graphics/weapons/sword/full.png"},
|
|
@ -10,7 +10,7 @@ class MagicPlayer:
|
||||||
self.animation_player = animation_player
|
self.animation_player = animation_player
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
|
|
||||||
# Sound Setup
|
# Sound Setup
|
||||||
self.sounds = {
|
self.sounds = {
|
|
@ -10,7 +10,7 @@ class AnimationPlayer:
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
|
|
||||||
self.frames = {
|
self.frames = {
|
||||||
# magic
|
# magic
|
|
@ -9,7 +9,7 @@ class Weapon(pygame.sprite.Sprite):
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
|
|
||||||
self.sprite_type = 'weapon'
|
self.sprite_type = 'weapon'
|
||||||
direction = player._input.status.split('_')[0]
|
direction = player._input.status.split('_')[0]
|
|
@ -36,18 +36,15 @@ class InputHandler:
|
||||||
self.magic_swap_time = None
|
self.magic_swap_time = None
|
||||||
|
|
||||||
# Setup Action Space
|
# Setup Action Space
|
||||||
|
self.num_actions = 7
|
||||||
self.action = 10
|
self.action = 10
|
||||||
|
|
||||||
def check_input(self, speed, hitbox, obstacle_sprites, rect, player):
|
def check_input(self, button, speed, hitbox, obstacle_sprites, rect, player):
|
||||||
|
|
||||||
self.action = 10
|
self.action = 10
|
||||||
|
|
||||||
if not self.attacking and self.can_move:
|
if not self.attacking and self.can_move:
|
||||||
|
|
||||||
keys = pygame.key.get_pressed()
|
|
||||||
|
|
||||||
button = randint(0, 4)
|
|
||||||
|
|
||||||
self.move_time = pygame.time.get_ticks()
|
self.move_time = pygame.time.get_ticks()
|
||||||
|
|
||||||
# Movement Input
|
# Movement Input
|
|
@ -19,7 +19,7 @@ class AnimationHandler:
|
||||||
def import_assets(self, position):
|
def import_assets(self, position):
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets', 'graphics')
|
script_dir, '../..', 'assets', 'graphics')
|
||||||
|
|
||||||
if self.sprite_type == 'player':
|
if self.sprite_type == 'player':
|
||||||
|
|
|
@ -9,7 +9,7 @@ class AudioHandler:
|
||||||
def __init__(self, sprite_type, monster_name=None):
|
def __init__(self, sprite_type, monster_name=None):
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets', 'audio')
|
script_dir, '../..', 'assets', 'audio')
|
||||||
|
|
||||||
if sprite_type == 'player':
|
if sprite_type == 'player':
|
||||||
pass
|
pass
|
|
@ -33,7 +33,7 @@ class CombatHandler:
|
||||||
# Import Sounds
|
# Import Sounds
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../../..', 'assets', 'audio')
|
script_dir, '../..', 'assets', 'audio')
|
||||||
|
|
||||||
self.weapon_attack_sound = pygame.mixer.Sound(
|
self.weapon_attack_sound = pygame.mixer.Sound(
|
||||||
f"{asset_path}/sword.wav")
|
f"{asset_path}/sword.wav")
|
|
@ -17,6 +17,7 @@ class Enemy(pygame.sprite.Sprite):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.visible_sprites = visible_sprites
|
self.visible_sprites = visible_sprites
|
||||||
|
|
||||||
|
self.position = position
|
||||||
# Setup Graphics
|
# Setup Graphics
|
||||||
self.audio = AudioHandler(self.sprite_type, self.name)
|
self.audio = AudioHandler(self.sprite_type, self.name)
|
||||||
self.animation_player = AnimationPlayer()
|
self.animation_player = AnimationPlayer()
|
|
@ -13,7 +13,7 @@ class Observer(pygame.sprite.Sprite):
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
|
|
||||||
self.image = pygame.image.load(
|
self.image = pygame.image.load(
|
||||||
f"{asset_path}/graphics/observer.png").convert_alpha()
|
f"{asset_path}/graphics/observer.png").convert_alpha()
|
|
@ -10,14 +10,18 @@ from .components.animaton import AnimationHandler
|
||||||
|
|
||||||
from effects.particle_effects import AnimationPlayer
|
from effects.particle_effects import AnimationPlayer
|
||||||
|
|
||||||
|
from agent.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
class Player(pygame.sprite.Sprite):
|
class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role):
|
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role, player_id, extract_features, convert_features_to_tensor):
|
||||||
super().__init__(groups)
|
super().__init__(groups)
|
||||||
|
|
||||||
# Setup Sprites
|
# Setup Sprites
|
||||||
self.sprite_type = 'player'
|
self.sprite_type = 'player'
|
||||||
|
self.status = 'down'
|
||||||
|
self.player_id = player_id
|
||||||
self.visible_sprites = visible_sprites
|
self.visible_sprites = visible_sprites
|
||||||
self.attack_sprites = attack_sprites
|
self.attack_sprites = attack_sprites
|
||||||
self.obstacle_sprites = obstacle_sprites
|
self.obstacle_sprites = obstacle_sprites
|
||||||
|
@ -40,6 +44,14 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
self.distance_direction_from_enemy = None
|
self.distance_direction_from_enemy = None
|
||||||
|
|
||||||
|
# Setup AI
|
||||||
|
self.extract_features = extract_features
|
||||||
|
self.convert_features_to_tensor = convert_features_to_tensor
|
||||||
|
self.agent = Agent(input_dims=398, n_actions=self._input.num_actions)
|
||||||
|
self.state_tensor = None
|
||||||
|
self.action_tensor = None
|
||||||
|
self.reward_tensor = None
|
||||||
|
|
||||||
def get_status(self):
|
def get_status(self):
|
||||||
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
|
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
|
||||||
if 'idle' not in self.status and 'attack' not in self.status:
|
if 'idle' not in self.status and 'attack' not in self.status:
|
||||||
|
@ -85,10 +97,41 @@ class Player(pygame.sprite.Sprite):
|
||||||
spell_damage = magic_data[self._input.combat.magic]['strength']
|
spell_damage = magic_data[self._input.combat.magic]['strength']
|
||||||
return (base_damage + spell_damage)
|
return (base_damage + spell_damage)
|
||||||
|
|
||||||
|
def get_current_state(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def is_dead(self):
|
||||||
|
if self.stats.health == 0:
|
||||||
|
self.stats.exp = -10
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
self.extract_features()
|
||||||
|
self.convert_features_to_tensor()
|
||||||
|
|
||||||
|
# Choose action based on current state
|
||||||
|
action, probs, value = self.agent.choose_action(self.state_tensor)
|
||||||
|
|
||||||
|
print(action)
|
||||||
|
# Apply chosen action
|
||||||
|
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
|
||||||
|
self.obstacle_sprites, self.animation.rect, self)
|
||||||
|
|
||||||
|
done = self.is_dead()
|
||||||
|
|
||||||
|
self.extract_features()
|
||||||
|
self.convert_features_to_tensor()
|
||||||
|
|
||||||
|
self.agent.remember(self.state_tensor, self.action_tensor,
|
||||||
|
probs, value, self.reward_tensor, done)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
self.agent.learn()
|
||||||
|
self.agent.memory.clear_memory()
|
||||||
|
|
||||||
# Refresh objects based on input
|
# Refresh objects based on input
|
||||||
self._input.check_input(
|
|
||||||
self.stats.speed, self.animation.hitbox, self.obstacle_sprites, self.animation.rect, self)
|
|
||||||
self.status = self._input.status
|
self.status = self._input.status
|
||||||
|
|
||||||
# Animate
|
# Animate
|
|
@ -2,7 +2,7 @@ import os
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
|
|
||||||
# ui
|
# ui
|
||||||
BAR_HEIGHT = 20
|
BAR_HEIGHT = 20
|
|
@ -16,7 +16,7 @@ class Camera(pygame.sprite.Group):
|
||||||
# Creating the floor
|
# Creating the floor
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
image_path = os.path.join(
|
image_path = os.path.join(
|
||||||
script_dir, '../..', 'assets', 'graphics', 'tilemap', 'ground.png')
|
script_dir, '..', 'assets', 'graphics', 'tilemap', 'ground.png')
|
||||||
|
|
||||||
self.floor_surf = pygame.image.load(image_path).convert()
|
self.floor_surf = pygame.image.load(image_path).convert()
|
||||||
self.floor_rect = self.floor_surf.get_rect(topleft=(0, 0))
|
self.floor_rect = self.floor_surf.get_rect(topleft=(0, 0))
|
|
@ -9,7 +9,6 @@ from utils.debug import debug
|
||||||
from utils.resource_loader import import_csv_layout, import_folder
|
from utils.resource_loader import import_csv_layout, import_folder
|
||||||
|
|
||||||
from interface.ui import UI
|
from interface.ui import UI
|
||||||
from interface.upgrade import Upgrade
|
|
||||||
|
|
||||||
from entities.observer import Observer
|
from entities.observer import Observer
|
||||||
from entities.player import Player
|
from entities.player import Player
|
||||||
|
@ -21,11 +20,16 @@ from .camera import Camera
|
||||||
|
|
||||||
class Level:
|
class Level:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, extract_features,
|
||||||
|
convert_features_to_tensor):
|
||||||
|
|
||||||
# General Settings
|
# General Settings
|
||||||
self.game_paused = False
|
self.game_paused = False
|
||||||
|
|
||||||
|
# AI setup
|
||||||
|
self.extract_features = extract_features
|
||||||
|
self.convert_features_to_tensor = convert_features_to_tensor
|
||||||
|
|
||||||
# Get display surface
|
# Get display surface
|
||||||
self.display_surface = pygame.display.get_surface()
|
self.display_surface = pygame.display.get_surface()
|
||||||
|
|
||||||
|
@ -37,18 +41,17 @@ class Level:
|
||||||
|
|
||||||
# Sprite setup and entity generation
|
# Sprite setup and entity generation
|
||||||
self.create_map()
|
self.create_map()
|
||||||
|
|
||||||
# UI setup
|
|
||||||
self.ui = UI()
|
|
||||||
# self.upgrade = Upgrade(self.player)
|
|
||||||
|
|
||||||
self.get_players_enemies()
|
self.get_players_enemies()
|
||||||
self.get_distance_direction()
|
self.get_distance_direction()
|
||||||
|
|
||||||
|
# UI setup
|
||||||
|
self.ui = UI()
|
||||||
|
|
||||||
def create_map(self):
|
def create_map(self):
|
||||||
|
player_id = 0
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
asset_path = os.path.join(
|
asset_path = os.path.join(
|
||||||
script_dir, '../..', 'assets')
|
script_dir, '..', 'assets')
|
||||||
layouts = {
|
layouts = {
|
||||||
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
|
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
|
||||||
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
|
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
|
||||||
|
@ -89,17 +92,20 @@ class Level:
|
||||||
elif col == '400':
|
elif col == '400':
|
||||||
# Player Generation
|
# Player Generation
|
||||||
Player(
|
Player(
|
||||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank')
|
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||||
|
player_id += 1
|
||||||
|
|
||||||
elif col == '401':
|
elif col == '401':
|
||||||
# Player Generation
|
# Player Generation
|
||||||
Player(
|
Player(
|
||||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior')
|
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||||
|
player_id += 1
|
||||||
|
|
||||||
elif col == '402':
|
elif col == '402':
|
||||||
# Player Generation
|
# Player Generation
|
||||||
Player(
|
Player(
|
||||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage')
|
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||||
|
player_id += 1
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Monster Generation
|
# Monster Generation
|
||||||
|
@ -167,6 +173,14 @@ class Level:
|
||||||
|
|
||||||
debug('v0.6')
|
debug('v0.6')
|
||||||
|
|
||||||
|
for player in self.player_sprites:
|
||||||
|
if player.is_dead():
|
||||||
|
print(player.stats.health)
|
||||||
|
player.kill()
|
||||||
|
|
||||||
|
if self.player_sprites == []:
|
||||||
|
self.__init__()
|
||||||
|
|
||||||
if not self.game_paused:
|
if not self.game_paused:
|
||||||
# Update the game
|
# Update the game
|
||||||
for player in self.player_sprites:
|
for player in self.player_sprites:
|
||||||
|
@ -177,11 +191,5 @@ class Level:
|
||||||
self.apply_damage_to_player()
|
self.apply_damage_to_player()
|
||||||
self.visible_sprites.update()
|
self.visible_sprites.update()
|
||||||
|
|
||||||
# self.visible_sprites.enemy_update(self.player)
|
|
||||||
# self.player_attack_logic()
|
|
||||||
else:
|
else:
|
||||||
debug('PAUSED')
|
debug('PAUSED')
|
||||||
|
|
||||||
for player in self.player_sprites:
|
|
||||||
if player.stats.health <= 0:
|
|
||||||
player.kill()
|
|
|
@ -19,10 +19,11 @@ class Game:
|
||||||
pygame.display.set_caption('Pneuma')
|
pygame.display.set_caption('Pneuma')
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
|
|
||||||
self.level = Level()
|
self.level = Level(self.extract_features,
|
||||||
|
self.convert_features_to_tensor)
|
||||||
|
|
||||||
# Sound
|
# Sound
|
||||||
main_sound = pygame.mixer.Sound('../assets/audio/main.ogg')
|
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
|
||||||
main_sound.set_volume(0.4)
|
main_sound.set_volume(0.4)
|
||||||
main_sound.play(loops=-1)
|
main_sound.play(loops=-1)
|
||||||
|
|
||||||
|
@ -35,14 +36,17 @@ class Game:
|
||||||
for i, player in enumerate(self.level.player_sprites):
|
for i, player in enumerate(self.level.player_sprites):
|
||||||
|
|
||||||
player_action_features = {
|
player_action_features = {
|
||||||
|
"player_id": player.player_id,
|
||||||
"player_action": player._input.action
|
"player_action": player._input.action
|
||||||
}
|
}
|
||||||
|
|
||||||
player_reward_features = {
|
player_reward_features = {
|
||||||
|
"player_id": player.player_id,
|
||||||
"player_exp": player.stats.exp
|
"player_exp": player.stats.exp
|
||||||
}
|
}
|
||||||
|
|
||||||
player_state_features = {
|
player_state_features = {
|
||||||
|
"player_id": player.player_id,
|
||||||
"player_position": player.rect.center,
|
"player_position": player.rect.center,
|
||||||
"player role": player.stats.role_id,
|
"player role": player.stats.role_id,
|
||||||
"player_health": player.stats.health,
|
"player_health": player.stats.health,
|
||||||
|
@ -80,10 +84,6 @@ class Game:
|
||||||
|
|
||||||
def convert_features_to_tensor(self):
|
def convert_features_to_tensor(self):
|
||||||
|
|
||||||
self.state_tensors = []
|
|
||||||
self.action_tensors = []
|
|
||||||
self.reward_tensors = []
|
|
||||||
|
|
||||||
for features in self.state_features:
|
for features in self.state_features:
|
||||||
info_array = []
|
info_array = []
|
||||||
|
|
||||||
|
@ -125,8 +125,9 @@ class Game:
|
||||||
state_tensor = torch.tensor(
|
state_tensor = torch.tensor(
|
||||||
np.array(info_array, dtype=np.float32))
|
np.array(info_array, dtype=np.float32))
|
||||||
|
|
||||||
self.state_tensors.append(state_tensor)
|
for player in self.level.player_sprites:
|
||||||
|
if player.player_id == features["player_id"]:
|
||||||
|
player.state_tensor = state_tensor
|
||||||
|
|
||||||
for features in self.action_features:
|
for features in self.action_features:
|
||||||
info_array = []
|
info_array = []
|
||||||
|
@ -139,7 +140,9 @@ class Game:
|
||||||
action_tensor = torch.tensor(
|
action_tensor = torch.tensor(
|
||||||
np.array(action_info, dtype=np.float32))
|
np.array(action_info, dtype=np.float32))
|
||||||
|
|
||||||
self.action_tensors.append(action_tensor)
|
for player in self.level.player_sprites:
|
||||||
|
if player.player_id == features["player_id"]:
|
||||||
|
player.action_tensor = action_tensor
|
||||||
|
|
||||||
for features in self.reward_features:
|
for features in self.reward_features:
|
||||||
info_array = []
|
info_array = []
|
||||||
|
@ -152,7 +155,9 @@ class Game:
|
||||||
reward_tensor = torch.tensor(
|
reward_tensor = torch.tensor(
|
||||||
np.array(reward_info, dtype=np.float32))
|
np.array(reward_info, dtype=np.float32))
|
||||||
|
|
||||||
self.reward_tensors.append(reward_tensor)
|
for player in self.level.player_sprites:
|
||||||
|
if player.player_id == features["player_id"]:
|
||||||
|
player.reward_tensor = reward_tensor
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
@ -166,10 +171,8 @@ class Game:
|
||||||
|
|
||||||
self.screen.fill(WATER_COLOR)
|
self.screen.fill(WATER_COLOR)
|
||||||
|
|
||||||
self.extract_features()
|
self.level.run(who='observer')
|
||||||
self.convert_features_to_tensor()
|
|
||||||
|
|
||||||
self.level.run('observer')
|
|
||||||
pygame.display.update()
|
pygame.display.update()
|
||||||
self.clock.tick(FPS)
|
self.clock.tick(FPS)
|
||||||
|
|
||||||
|
@ -179,9 +182,4 @@ if __name__ == '__main__':
|
||||||
game = Game()
|
game = Game()
|
||||||
for i in range(0, 10000):
|
for i in range(0, 10000):
|
||||||
game.run()
|
game.run()
|
||||||
game.extract_features()
|
print(i)
|
||||||
game.convert_features_to_tensor()
|
|
||||||
if i == 100:
|
|
||||||
print(game.reward_tensors)
|
|
||||||
print(game.action_tensors)
|
|
||||||
print(game.state_tensors)
|
|
|
@ -2,6 +2,7 @@ import pygame
|
||||||
from csv import reader
|
from csv import reader
|
||||||
from os import walk
|
from os import walk
|
||||||
|
|
||||||
|
|
||||||
def import_csv_layout(path):
|
def import_csv_layout(path):
|
||||||
terrain_map = []
|
terrain_map = []
|
||||||
with open(path) as level_map:
|
with open(path) as level_map:
|
||||||
|
@ -10,6 +11,7 @@ def import_csv_layout(path):
|
||||||
terrain_map.append(list(row))
|
terrain_map.append(list(row))
|
||||||
return terrain_map
|
return terrain_map
|
||||||
|
|
||||||
|
|
||||||
def import_folder(path):
|
def import_folder(path):
|
||||||
surface_list = []
|
surface_list = []
|
||||||
|
|
Loading…
Reference in a new issue