Implemented (badly) agent

This commit is contained in:
Vasilis Valatsos 2023-11-14 22:44:43 +01:00
parent d20f46bb9d
commit b4a6e99fce
38 changed files with 200 additions and 158 deletions

View file

@ -1,14 +1,12 @@
import random import numpy as np
import torch import torch as T
from numpy.random import default_rng from .brain import ActorNetwork, CriticNetwork, PPOMemory
from rl.brain import ActorNetwork, CriticNetwork, PPOMemory
class Agent: class Agent:
def __init__(self, n_actions, input_dims, gamma = 0.99, alpha = 0.0003, policy_clip = 0.2, batch_size = 64, N=2048, n_epochs = 10, gae_lambda = 0.95): def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, policy_clip=0.2, batch_size=64, N=2048, n_epochs=10, gae_lambda=0.95):
self.gamma = gamma self.gamma = gamma
self.policy_clip = policy_clip self.policy_clip = policy_clip
@ -33,13 +31,13 @@ class Agent:
print('... done ...') print('... done ...')
def load_models(self): def load_models(self):
print('... loadng models ...') print('... loading models ...')
self.actor.load_checkpoint() self.actor.load_checkpoint()
self.critic.load_checkpoint() self.critic.load_checkpoint()
print('.. done ...') print('.. done ...')
def choose_action(self, observation): def choose_action(self, observation):
state = T.tensor([observation], dtype = T.float).to(self.actor.device) state = observation.to(self.actor.device, dtype=T.float)
dist = self.actor(state) dist = self.actor(state)
value = self.critic(state) value = self.critic(state)
@ -62,15 +60,19 @@ class Agent:
discount = 1 discount = 1
a_t = 0 a_t = 0
for k in range(t, len(reward_arr)-1): for k in range(t, len(reward_arr)-1):
a_t += discount*(reward_arr[k] + self.gamma*values[k+1]*(1-int(dones_arr[k])) - values[k]) a_t += discount * \
(reward_arr[k] + self.gamma*values[k+1]
* (1-int(done_arr[k])) - values[k])
discount *= self.gamma * self.gae_lambda discount *= self.gamma * self.gae_lambda
advantage[t] = a_t advantage[t] = a_t
advantage = T.tensor(Advantage).to(self.actor.device) advantage = T.tensor(advantage).to(self.actor.device)
values = T.tensor(values).to(self.actor.device) values = T.tensor(values).to(self.actor.device)
for batch in batches: for batch in batches:
states = T.tensor(state_arr[batch], dtype = T.float).to(self.actor.device) states = T.tensor(state_arr[batch], dtype=T.float).to(
old_probs = T.tensor(old_probs_arr[batch]).to(self.actor.device) self.actor.device)
old_probs = T.tensor(old_probs_arr[batch]).to(
self.actor.device)
actions = T.tensor(action_arr[batch]).to(self.actor.device) actions = T.tensor(action_arr[batch]).to(self.actor.device)
dist = self.actor(states) dist = self.actor(states)
@ -81,8 +83,10 @@ class Agent:
new_probs = dist.log_prob(actions) new_probs = dist.log_prob(actions)
prob_ratio = new_probs.exp() / old_probs.exp() prob_ratio = new_probs.exp() / old_probs.exp()
weighted_probs = advantage[batch] * prob_ratio weighted_probs = advantage[batch] * prob_ratio
weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch] weighted_clipped_probs = T.clamp(
actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean() prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
actor_loss = -T.min(weighted_probs,
weighted_clipped_probs).mean()
returns = advantage[batch] + values[batch] returns = advantage[batch] + values[batch]
critic_loss = (returns - critic_value)**2 critic_loss = (returns - critic_value)**2

View file

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.distributions.categorical import Categorical from torch.distributions.categorical import Categorical
class PPOMemory: class PPOMemory:
def __init__(self, batch_size): def __init__(self, batch_size):
self.states = [] self.states = []
@ -47,6 +48,7 @@ class PPOMemory:
self.rewards = [] self.rewards = []
self.dones = [] self.dones = []
class ActorNetwork(nn.Module): class ActorNetwork(nn.Module):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'): def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
@ -54,17 +56,18 @@ class ActorNetwork(nn.Module):
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo') self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
self.actor = nn.Sequential( self.actor = nn.Sequential(
nn.Linear(len(input_dim), fc1_dims), nn.Linear(input_dim, fc1_dims),
nn.ReLU(), nn.ReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.ReLU(), nn.ReLU(),
nn.Linear(fc2_dims, len(output_dim)), nn.Linear(fc2_dims, output_dim),
nn.Softmax(dim=-1) nn.Softmax(dim=-1)
) )
self.optimizer = optim.Adam(self.parameters(), lr=alpha) self.optimizer = optim.Adam(self.parameters(), lr=alpha)
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu')) self.device = T.device('cuda:0' if T.cuda.is_available() else (
'mps' if T.backends.mps.is_available() else 'cpu'))
self.to(self.device) self.to(self.device)
@ -80,6 +83,7 @@ class ActorNetwork(nn.Module):
def load_checkpoint(self): def load_checkpoint(self):
self.load_state_dict(T.load(self.checkpoint_file)) self.load_state_dict(T.load(self.checkpoint_file))
class CriticNetwork(nn.Module): class CriticNetwork(nn.Module):
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'): def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
@ -87,7 +91,7 @@ class CriticNetwork(nn.Module):
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo') self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
self.critic = nn.Sequential( self.critic = nn.Sequential(
nn.Linear(len(input_dims), fc1_dims), nn.Linear(input_dims, fc1_dims),
nn.ReLU(), nn.ReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.ReLU(), nn.ReLU(),
@ -95,12 +99,13 @@ class CriticNetwork(nn.Module):
) )
self.optimizer = optim.Adam(self.parameters(), lr=alpha) self.optimizer = optim.Adam(self.parameters(), lr=alpha)
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu')) self.device = T.device('cuda:0' if T.cuda.is_available() else (
'mps' if T.backends.mps.is_available() else 'cpu'))
self.to(self.device) self.to(self.device)
def forward(self, state): def forward(self, state):
vale = self.critic(state) value = self.critic(state)
return value return value
@ -109,19 +114,3 @@ class CriticNetwork(nn.Module):
def load_checkpoint(self): def load_checkpoint(self):
self.load_state_dict(T.load(self.checkpoint_file)) self.load_state_dict(T.load(self.checkpoint_file))

View file

@ -3,7 +3,7 @@ import os
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets') script_dir, '../..', 'assets')
monster_data = { monster_data = {
'squid': {'id': 1, 'health': 100, 'exp': 100, 'attack': 20, 'attack_type': 'slash', 'attack_sound': f'{asset_path}/audio/attack/slash.wav', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360}, 'squid': {'id': 1, 'health': 100, 'exp': 100, 'attack': 20, 'attack_type': 'slash', 'attack_sound': f'{asset_path}/audio/attack/slash.wav', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},

View file

@ -2,7 +2,7 @@ import os
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets') script_dir, '../..', 'assets')
magic_data = { magic_data = {
'flame': {'strength': 5, 'cost': 20, 'graphic': f"{asset_path}/graphics/particles/flame/fire.png"}, 'flame': {'strength': 5, 'cost': 20, 'graphic': f"{asset_path}/graphics/particles/flame/fire.png"},

View file

@ -2,7 +2,7 @@ import os
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets') script_dir, '../..', 'assets')
weapon_data = { weapon_data = {
'sword': {'cooldown': 100, 'damage': 15, 'graphic': f"{asset_path}/graphics/weapons/sword/full.png"}, 'sword': {'cooldown': 100, 'damage': 15, 'graphic': f"{asset_path}/graphics/weapons/sword/full.png"},

View file

@ -10,7 +10,7 @@ class MagicPlayer:
self.animation_player = animation_player self.animation_player = animation_player
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
# Sound Setup # Sound Setup
self.sounds = { self.sounds = {

View file

@ -10,7 +10,7 @@ class AnimationPlayer:
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
self.frames = { self.frames = {
# magic # magic

View file

@ -9,7 +9,7 @@ class Weapon(pygame.sprite.Sprite):
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
self.sprite_type = 'weapon' self.sprite_type = 'weapon'
direction = player._input.status.split('_')[0] direction = player._input.status.split('_')[0]

View file

@ -36,18 +36,15 @@ class InputHandler:
self.magic_swap_time = None self.magic_swap_time = None
# Setup Action Space # Setup Action Space
self.num_actions = 7
self.action = 10 self.action = 10
def check_input(self, speed, hitbox, obstacle_sprites, rect, player): def check_input(self, button, speed, hitbox, obstacle_sprites, rect, player):
self.action = 10 self.action = 10
if not self.attacking and self.can_move: if not self.attacking and self.can_move:
keys = pygame.key.get_pressed()
button = randint(0, 4)
self.move_time = pygame.time.get_ticks() self.move_time = pygame.time.get_ticks()
# Movement Input # Movement Input

View file

@ -19,7 +19,7 @@ class AnimationHandler:
def import_assets(self, position): def import_assets(self, position):
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets', 'graphics') script_dir, '../..', 'assets', 'graphics')
if self.sprite_type == 'player': if self.sprite_type == 'player':

View file

@ -9,7 +9,7 @@ class AudioHandler:
def __init__(self, sprite_type, monster_name=None): def __init__(self, sprite_type, monster_name=None):
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets', 'audio') script_dir, '../..', 'assets', 'audio')
if sprite_type == 'player': if sprite_type == 'player':
pass pass

View file

@ -33,7 +33,7 @@ class CombatHandler:
# Import Sounds # Import Sounds
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../../..', 'assets', 'audio') script_dir, '../..', 'assets', 'audio')
self.weapon_attack_sound = pygame.mixer.Sound( self.weapon_attack_sound = pygame.mixer.Sound(
f"{asset_path}/sword.wav") f"{asset_path}/sword.wav")

View file

@ -17,6 +17,7 @@ class Enemy(pygame.sprite.Sprite):
self.name = name self.name = name
self.visible_sprites = visible_sprites self.visible_sprites = visible_sprites
self.position = position
# Setup Graphics # Setup Graphics
self.audio = AudioHandler(self.sprite_type, self.name) self.audio = AudioHandler(self.sprite_type, self.name)
self.animation_player = AnimationPlayer() self.animation_player = AnimationPlayer()

View file

@ -13,7 +13,7 @@ class Observer(pygame.sprite.Sprite):
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
self.image = pygame.image.load( self.image = pygame.image.load(
f"{asset_path}/graphics/observer.png").convert_alpha() f"{asset_path}/graphics/observer.png").convert_alpha()

View file

@ -10,14 +10,18 @@ from .components.animaton import AnimationHandler
from effects.particle_effects import AnimationPlayer from effects.particle_effects import AnimationPlayer
from agent.agent import Agent
class Player(pygame.sprite.Sprite): class Player(pygame.sprite.Sprite):
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role): def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role, player_id, extract_features, convert_features_to_tensor):
super().__init__(groups) super().__init__(groups)
# Setup Sprites # Setup Sprites
self.sprite_type = 'player' self.sprite_type = 'player'
self.status = 'down'
self.player_id = player_id
self.visible_sprites = visible_sprites self.visible_sprites = visible_sprites
self.attack_sprites = attack_sprites self.attack_sprites = attack_sprites
self.obstacle_sprites = obstacle_sprites self.obstacle_sprites = obstacle_sprites
@ -40,6 +44,14 @@ class Player(pygame.sprite.Sprite):
self.distance_direction_from_enemy = None self.distance_direction_from_enemy = None
# Setup AI
self.extract_features = extract_features
self.convert_features_to_tensor = convert_features_to_tensor
self.agent = Agent(input_dims=398, n_actions=self._input.num_actions)
self.state_tensor = None
self.action_tensor = None
self.reward_tensor = None
def get_status(self): def get_status(self):
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0: if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
if 'idle' not in self.status and 'attack' not in self.status: if 'idle' not in self.status and 'attack' not in self.status:
@ -85,10 +97,41 @@ class Player(pygame.sprite.Sprite):
spell_damage = magic_data[self._input.combat.magic]['strength'] spell_damage = magic_data[self._input.combat.magic]['strength']
return (base_damage + spell_damage) return (base_damage + spell_damage)
def get_current_state(self):
pass
def is_dead(self):
if self.stats.health == 0:
self.stats.exp = -10
return True
else:
return False
def update(self): def update(self):
self.extract_features()
self.convert_features_to_tensor()
# Choose action based on current state
action, probs, value = self.agent.choose_action(self.state_tensor)
print(action)
# Apply chosen action
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
self.obstacle_sprites, self.animation.rect, self)
done = self.is_dead()
self.extract_features()
self.convert_features_to_tensor()
self.agent.remember(self.state_tensor, self.action_tensor,
probs, value, self.reward_tensor, done)
if done:
self.agent.learn()
self.agent.memory.clear_memory()
# Refresh objects based on input # Refresh objects based on input
self._input.check_input(
self.stats.speed, self.animation.hitbox, self.obstacle_sprites, self.animation.rect, self)
self.status = self._input.status self.status = self._input.status
# Animate # Animate

View file

@ -2,7 +2,7 @@ import os
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
# ui # ui
BAR_HEIGHT = 20 BAR_HEIGHT = 20

View file

@ -16,7 +16,7 @@ class Camera(pygame.sprite.Group):
# Creating the floor # Creating the floor
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
image_path = os.path.join( image_path = os.path.join(
script_dir, '../..', 'assets', 'graphics', 'tilemap', 'ground.png') script_dir, '..', 'assets', 'graphics', 'tilemap', 'ground.png')
self.floor_surf = pygame.image.load(image_path).convert() self.floor_surf = pygame.image.load(image_path).convert()
self.floor_rect = self.floor_surf.get_rect(topleft=(0, 0)) self.floor_rect = self.floor_surf.get_rect(topleft=(0, 0))

View file

@ -9,7 +9,6 @@ from utils.debug import debug
from utils.resource_loader import import_csv_layout, import_folder from utils.resource_loader import import_csv_layout, import_folder
from interface.ui import UI from interface.ui import UI
from interface.upgrade import Upgrade
from entities.observer import Observer from entities.observer import Observer
from entities.player import Player from entities.player import Player
@ -21,11 +20,16 @@ from .camera import Camera
class Level: class Level:
def __init__(self): def __init__(self, extract_features,
convert_features_to_tensor):
# General Settings # General Settings
self.game_paused = False self.game_paused = False
# AI setup
self.extract_features = extract_features
self.convert_features_to_tensor = convert_features_to_tensor
# Get display surface # Get display surface
self.display_surface = pygame.display.get_surface() self.display_surface = pygame.display.get_surface()
@ -37,18 +41,17 @@ class Level:
# Sprite setup and entity generation # Sprite setup and entity generation
self.create_map() self.create_map()
# UI setup
self.ui = UI()
# self.upgrade = Upgrade(self.player)
self.get_players_enemies() self.get_players_enemies()
self.get_distance_direction() self.get_distance_direction()
# UI setup
self.ui = UI()
def create_map(self): def create_map(self):
player_id = 0
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
asset_path = os.path.join( asset_path = os.path.join(
script_dir, '../..', 'assets') script_dir, '..', 'assets')
layouts = { layouts = {
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"), 'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"), 'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
@ -89,17 +92,20 @@ class Level:
elif col == '400': elif col == '400':
# Player Generation # Player Generation
Player( Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank') (x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id, self.extract_features, self.convert_features_to_tensor)
player_id += 1
elif col == '401': elif col == '401':
# Player Generation # Player Generation
Player( Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior') (x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id, self.extract_features, self.convert_features_to_tensor)
player_id += 1
elif col == '402': elif col == '402':
# Player Generation # Player Generation
Player( Player(
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage') (x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id, self.extract_features, self.convert_features_to_tensor)
player_id += 1
else: else:
# Monster Generation # Monster Generation
@ -167,6 +173,14 @@ class Level:
debug('v0.6') debug('v0.6')
for player in self.player_sprites:
if player.is_dead():
print(player.stats.health)
player.kill()
if self.player_sprites == []:
self.__init__()
if not self.game_paused: if not self.game_paused:
# Update the game # Update the game
for player in self.player_sprites: for player in self.player_sprites:
@ -177,11 +191,5 @@ class Level:
self.apply_damage_to_player() self.apply_damage_to_player()
self.visible_sprites.update() self.visible_sprites.update()
# self.visible_sprites.enemy_update(self.player)
# self.player_attack_logic()
else: else:
debug('PAUSED') debug('PAUSED')
for player in self.player_sprites:
if player.stats.health <= 0:
player.kill()

View file

@ -19,10 +19,11 @@ class Game:
pygame.display.set_caption('Pneuma') pygame.display.set_caption('Pneuma')
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
self.level = Level() self.level = Level(self.extract_features,
self.convert_features_to_tensor)
# Sound # Sound
main_sound = pygame.mixer.Sound('../assets/audio/main.ogg') main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
main_sound.set_volume(0.4) main_sound.set_volume(0.4)
main_sound.play(loops=-1) main_sound.play(loops=-1)
@ -35,14 +36,17 @@ class Game:
for i, player in enumerate(self.level.player_sprites): for i, player in enumerate(self.level.player_sprites):
player_action_features = { player_action_features = {
"player_id": player.player_id,
"player_action": player._input.action "player_action": player._input.action
} }
player_reward_features = { player_reward_features = {
"player_id": player.player_id,
"player_exp": player.stats.exp "player_exp": player.stats.exp
} }
player_state_features = { player_state_features = {
"player_id": player.player_id,
"player_position": player.rect.center, "player_position": player.rect.center,
"player role": player.stats.role_id, "player role": player.stats.role_id,
"player_health": player.stats.health, "player_health": player.stats.health,
@ -80,10 +84,6 @@ class Game:
def convert_features_to_tensor(self): def convert_features_to_tensor(self):
self.state_tensors = []
self.action_tensors = []
self.reward_tensors = []
for features in self.state_features: for features in self.state_features:
info_array = [] info_array = []
@ -125,8 +125,9 @@ class Game:
state_tensor = torch.tensor( state_tensor = torch.tensor(
np.array(info_array, dtype=np.float32)) np.array(info_array, dtype=np.float32))
self.state_tensors.append(state_tensor) for player in self.level.player_sprites:
if player.player_id == features["player_id"]:
player.state_tensor = state_tensor
for features in self.action_features: for features in self.action_features:
info_array = [] info_array = []
@ -139,7 +140,9 @@ class Game:
action_tensor = torch.tensor( action_tensor = torch.tensor(
np.array(action_info, dtype=np.float32)) np.array(action_info, dtype=np.float32))
self.action_tensors.append(action_tensor) for player in self.level.player_sprites:
if player.player_id == features["player_id"]:
player.action_tensor = action_tensor
for features in self.reward_features: for features in self.reward_features:
info_array = [] info_array = []
@ -152,7 +155,9 @@ class Game:
reward_tensor = torch.tensor( reward_tensor = torch.tensor(
np.array(reward_info, dtype=np.float32)) np.array(reward_info, dtype=np.float32))
self.reward_tensors.append(reward_tensor) for player in self.level.player_sprites:
if player.player_id == features["player_id"]:
player.reward_tensor = reward_tensor
def run(self): def run(self):
@ -166,10 +171,8 @@ class Game:
self.screen.fill(WATER_COLOR) self.screen.fill(WATER_COLOR)
self.extract_features() self.level.run(who='observer')
self.convert_features_to_tensor()
self.level.run('observer')
pygame.display.update() pygame.display.update()
self.clock.tick(FPS) self.clock.tick(FPS)
@ -179,9 +182,4 @@ if __name__ == '__main__':
game = Game() game = Game()
for i in range(0, 10000): for i in range(0, 10000):
game.run() game.run()
game.extract_features() print(i)
game.convert_features_to_tensor()
if i == 100:
print(game.reward_tensors)
print(game.action_tensors)
print(game.state_tensors)

View file

@ -2,6 +2,7 @@ import pygame
from csv import reader from csv import reader
from os import walk from os import walk
def import_csv_layout(path): def import_csv_layout(path):
terrain_map = [] terrain_map = []
with open(path) as level_map: with open(path) as level_map:
@ -10,6 +11,7 @@ def import_csv_layout(path):
terrain_map.append(list(row)) terrain_map.append(list(row))
return terrain_map return terrain_map
def import_folder(path): def import_folder(path):
surface_list = [] surface_list = []