Implemented (badly) agent
This commit is contained in:
parent
d20f46bb9d
commit
b4a6e99fce
38 changed files with 200 additions and 158 deletions
|
@ -1,14 +1,12 @@
|
|||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch as T
|
||||
|
||||
from numpy.random import default_rng
|
||||
|
||||
from rl.brain import ActorNetwork, CriticNetwork, PPOMemory
|
||||
from .brain import ActorNetwork, CriticNetwork, PPOMemory
|
||||
|
||||
|
||||
class Agent:
|
||||
|
||||
def __init__(self, n_actions, input_dims, gamma = 0.99, alpha = 0.0003, policy_clip = 0.2, batch_size = 64, N=2048, n_epochs = 10, gae_lambda = 0.95):
|
||||
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, policy_clip=0.2, batch_size=64, N=2048, n_epochs=10, gae_lambda=0.95):
|
||||
|
||||
self.gamma = gamma
|
||||
self.policy_clip = policy_clip
|
||||
|
@ -33,13 +31,13 @@ class Agent:
|
|||
print('... done ...')
|
||||
|
||||
def load_models(self):
|
||||
print('... loadng models ...')
|
||||
print('... loading models ...')
|
||||
self.actor.load_checkpoint()
|
||||
self.critic.load_checkpoint()
|
||||
print('.. done ...')
|
||||
|
||||
def choose_action(self, observation):
|
||||
state = T.tensor([observation], dtype = T.float).to(self.actor.device)
|
||||
state = observation.to(self.actor.device, dtype=T.float)
|
||||
|
||||
dist = self.actor(state)
|
||||
value = self.critic(state)
|
||||
|
@ -62,15 +60,19 @@ class Agent:
|
|||
discount = 1
|
||||
a_t = 0
|
||||
for k in range(t, len(reward_arr)-1):
|
||||
a_t += discount*(reward_arr[k] + self.gamma*values[k+1]*(1-int(dones_arr[k])) - values[k])
|
||||
a_t += discount * \
|
||||
(reward_arr[k] + self.gamma*values[k+1]
|
||||
* (1-int(done_arr[k])) - values[k])
|
||||
discount *= self.gamma * self.gae_lambda
|
||||
advantage[t] = a_t
|
||||
advantage = T.tensor(Advantage).to(self.actor.device)
|
||||
advantage = T.tensor(advantage).to(self.actor.device)
|
||||
|
||||
values = T.tensor(values).to(self.actor.device)
|
||||
for batch in batches:
|
||||
states = T.tensor(state_arr[batch], dtype = T.float).to(self.actor.device)
|
||||
old_probs = T.tensor(old_probs_arr[batch]).to(self.actor.device)
|
||||
states = T.tensor(state_arr[batch], dtype=T.float).to(
|
||||
self.actor.device)
|
||||
old_probs = T.tensor(old_probs_arr[batch]).to(
|
||||
self.actor.device)
|
||||
actions = T.tensor(action_arr[batch]).to(self.actor.device)
|
||||
|
||||
dist = self.actor(states)
|
||||
|
@ -81,8 +83,10 @@ class Agent:
|
|||
new_probs = dist.log_prob(actions)
|
||||
prob_ratio = new_probs.exp() / old_probs.exp()
|
||||
weighted_probs = advantage[batch] * prob_ratio
|
||||
weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
||||
actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()
|
||||
weighted_clipped_probs = T.clamp(
|
||||
prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
||||
actor_loss = -T.min(weighted_probs,
|
||||
weighted_clipped_probs).mean()
|
||||
|
||||
returns = advantage[batch] + values[batch]
|
||||
critic_loss = (returns - critic_value)**2
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
|
||||
class PPOMemory:
|
||||
def __init__(self, batch_size):
|
||||
self.states = []
|
||||
|
@ -47,6 +48,7 @@ class PPOMemory:
|
|||
self.rewards = []
|
||||
self.dones = []
|
||||
|
||||
|
||||
class ActorNetwork(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||
|
@ -54,17 +56,18 @@ class ActorNetwork(nn.Module):
|
|||
|
||||
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
|
||||
self.actor = nn.Sequential(
|
||||
nn.Linear(len(input_dim), fc1_dims),
|
||||
nn.Linear(input_dim, fc1_dims),
|
||||
nn.ReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.ReLU(),
|
||||
nn.Linear(fc2_dims, len(output_dim)),
|
||||
nn.Linear(fc2_dims, output_dim),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
|
@ -80,6 +83,7 @@ class ActorNetwork(nn.Module):
|
|||
def load_checkpoint(self):
|
||||
self.load_state_dict(T.load(self.checkpoint_file))
|
||||
|
||||
|
||||
class CriticNetwork(nn.Module):
|
||||
|
||||
def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
|
||||
|
@ -87,7 +91,7 @@ class CriticNetwork(nn.Module):
|
|||
|
||||
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
|
||||
self.critic = nn.Sequential(
|
||||
nn.Linear(len(input_dims), fc1_dims),
|
||||
nn.Linear(input_dims, fc1_dims),
|
||||
nn.ReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.ReLU(),
|
||||
|
@ -95,12 +99,13 @@ class CriticNetwork(nn.Module):
|
|||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha)
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else ('mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else (
|
||||
'mps' if T.backends.mps.is_available() else 'cpu'))
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, state):
|
||||
vale = self.critic(state)
|
||||
value = self.critic(state)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -109,19 +114,3 @@ class CriticNetwork(nn.Module):
|
|||
|
||||
def load_checkpoint(self):
|
||||
self.load_state_dict(T.load(self.checkpoint_file))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets')
|
||||
script_dir, '../..', 'assets')
|
||||
|
||||
monster_data = {
|
||||
'squid': {'id': 1, 'health': 100, 'exp': 100, 'attack': 20, 'attack_type': 'slash', 'attack_sound': f'{asset_path}/audio/attack/slash.wav', 'speed': 3, 'knockback': 20, 'attack_radius': 80, 'notice_radius': 360},
|
|
@ -2,7 +2,7 @@ import os
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets')
|
||||
script_dir, '../..', 'assets')
|
||||
|
||||
magic_data = {
|
||||
'flame': {'strength': 5, 'cost': 20, 'graphic': f"{asset_path}/graphics/particles/flame/fire.png"},
|
|
@ -2,7 +2,7 @@ import os
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets')
|
||||
script_dir, '../..', 'assets')
|
||||
|
||||
weapon_data = {
|
||||
'sword': {'cooldown': 100, 'damage': 15, 'graphic': f"{asset_path}/graphics/weapons/sword/full.png"},
|
|
@ -10,7 +10,7 @@ class MagicPlayer:
|
|||
self.animation_player = animation_player
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
|
||||
# Sound Setup
|
||||
self.sounds = {
|
|
@ -10,7 +10,7 @@ class AnimationPlayer:
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
|
||||
self.frames = {
|
||||
# magic
|
|
@ -9,7 +9,7 @@ class Weapon(pygame.sprite.Sprite):
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
|
||||
self.sprite_type = 'weapon'
|
||||
direction = player._input.status.split('_')[0]
|
|
@ -36,18 +36,15 @@ class InputHandler:
|
|||
self.magic_swap_time = None
|
||||
|
||||
# Setup Action Space
|
||||
self.num_actions = 7
|
||||
self.action = 10
|
||||
|
||||
def check_input(self, speed, hitbox, obstacle_sprites, rect, player):
|
||||
def check_input(self, button, speed, hitbox, obstacle_sprites, rect, player):
|
||||
|
||||
self.action = 10
|
||||
|
||||
if not self.attacking and self.can_move:
|
||||
|
||||
keys = pygame.key.get_pressed()
|
||||
|
||||
button = randint(0, 4)
|
||||
|
||||
self.move_time = pygame.time.get_ticks()
|
||||
|
||||
# Movement Input
|
|
@ -19,7 +19,7 @@ class AnimationHandler:
|
|||
def import_assets(self, position):
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets', 'graphics')
|
||||
script_dir, '../..', 'assets', 'graphics')
|
||||
|
||||
if self.sprite_type == 'player':
|
||||
|
|
@ -9,7 +9,7 @@ class AudioHandler:
|
|||
def __init__(self, sprite_type, monster_name=None):
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets', 'audio')
|
||||
script_dir, '../..', 'assets', 'audio')
|
||||
|
||||
if sprite_type == 'player':
|
||||
pass
|
|
@ -33,7 +33,7 @@ class CombatHandler:
|
|||
# Import Sounds
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../../..', 'assets', 'audio')
|
||||
script_dir, '../..', 'assets', 'audio')
|
||||
|
||||
self.weapon_attack_sound = pygame.mixer.Sound(
|
||||
f"{asset_path}/sword.wav")
|
|
@ -17,6 +17,7 @@ class Enemy(pygame.sprite.Sprite):
|
|||
self.name = name
|
||||
self.visible_sprites = visible_sprites
|
||||
|
||||
self.position = position
|
||||
# Setup Graphics
|
||||
self.audio = AudioHandler(self.sprite_type, self.name)
|
||||
self.animation_player = AnimationPlayer()
|
|
@ -13,7 +13,7 @@ class Observer(pygame.sprite.Sprite):
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
|
||||
self.image = pygame.image.load(
|
||||
f"{asset_path}/graphics/observer.png").convert_alpha()
|
|
@ -10,14 +10,18 @@ from .components.animaton import AnimationHandler
|
|||
|
||||
from effects.particle_effects import AnimationPlayer
|
||||
|
||||
from agent.agent import Agent
|
||||
|
||||
|
||||
class Player(pygame.sprite.Sprite):
|
||||
|
||||
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role):
|
||||
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role, player_id, extract_features, convert_features_to_tensor):
|
||||
super().__init__(groups)
|
||||
|
||||
# Setup Sprites
|
||||
self.sprite_type = 'player'
|
||||
self.status = 'down'
|
||||
self.player_id = player_id
|
||||
self.visible_sprites = visible_sprites
|
||||
self.attack_sprites = attack_sprites
|
||||
self.obstacle_sprites = obstacle_sprites
|
||||
|
@ -40,6 +44,14 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.distance_direction_from_enemy = None
|
||||
|
||||
# Setup AI
|
||||
self.extract_features = extract_features
|
||||
self.convert_features_to_tensor = convert_features_to_tensor
|
||||
self.agent = Agent(input_dims=398, n_actions=self._input.num_actions)
|
||||
self.state_tensor = None
|
||||
self.action_tensor = None
|
||||
self.reward_tensor = None
|
||||
|
||||
def get_status(self):
|
||||
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
|
||||
if 'idle' not in self.status and 'attack' not in self.status:
|
||||
|
@ -85,10 +97,41 @@ class Player(pygame.sprite.Sprite):
|
|||
spell_damage = magic_data[self._input.combat.magic]['strength']
|
||||
return (base_damage + spell_damage)
|
||||
|
||||
def get_current_state(self):
|
||||
pass
|
||||
|
||||
def is_dead(self):
|
||||
if self.stats.health == 0:
|
||||
self.stats.exp = -10
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def update(self):
|
||||
self.extract_features()
|
||||
self.convert_features_to_tensor()
|
||||
|
||||
# Choose action based on current state
|
||||
action, probs, value = self.agent.choose_action(self.state_tensor)
|
||||
|
||||
print(action)
|
||||
# Apply chosen action
|
||||
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
|
||||
self.obstacle_sprites, self.animation.rect, self)
|
||||
|
||||
done = self.is_dead()
|
||||
|
||||
self.extract_features()
|
||||
self.convert_features_to_tensor()
|
||||
|
||||
self.agent.remember(self.state_tensor, self.action_tensor,
|
||||
probs, value, self.reward_tensor, done)
|
||||
|
||||
if done:
|
||||
self.agent.learn()
|
||||
self.agent.memory.clear_memory()
|
||||
|
||||
# Refresh objects based on input
|
||||
self._input.check_input(
|
||||
self.stats.speed, self.animation.hitbox, self.obstacle_sprites, self.animation.rect, self)
|
||||
self.status = self._input.status
|
||||
|
||||
# Animate
|
|
@ -2,7 +2,7 @@ import os
|
|||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
|
||||
# ui
|
||||
BAR_HEIGHT = 20
|
|
@ -16,7 +16,7 @@ class Camera(pygame.sprite.Group):
|
|||
# Creating the floor
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
image_path = os.path.join(
|
||||
script_dir, '../..', 'assets', 'graphics', 'tilemap', 'ground.png')
|
||||
script_dir, '..', 'assets', 'graphics', 'tilemap', 'ground.png')
|
||||
|
||||
self.floor_surf = pygame.image.load(image_path).convert()
|
||||
self.floor_rect = self.floor_surf.get_rect(topleft=(0, 0))
|
|
@ -9,7 +9,6 @@ from utils.debug import debug
|
|||
from utils.resource_loader import import_csv_layout, import_folder
|
||||
|
||||
from interface.ui import UI
|
||||
from interface.upgrade import Upgrade
|
||||
|
||||
from entities.observer import Observer
|
||||
from entities.player import Player
|
||||
|
@ -21,11 +20,16 @@ from .camera import Camera
|
|||
|
||||
class Level:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, extract_features,
|
||||
convert_features_to_tensor):
|
||||
|
||||
# General Settings
|
||||
self.game_paused = False
|
||||
|
||||
# AI setup
|
||||
self.extract_features = extract_features
|
||||
self.convert_features_to_tensor = convert_features_to_tensor
|
||||
|
||||
# Get display surface
|
||||
self.display_surface = pygame.display.get_surface()
|
||||
|
||||
|
@ -37,18 +41,17 @@ class Level:
|
|||
|
||||
# Sprite setup and entity generation
|
||||
self.create_map()
|
||||
|
||||
# UI setup
|
||||
self.ui = UI()
|
||||
# self.upgrade = Upgrade(self.player)
|
||||
|
||||
self.get_players_enemies()
|
||||
self.get_distance_direction()
|
||||
|
||||
# UI setup
|
||||
self.ui = UI()
|
||||
|
||||
def create_map(self):
|
||||
player_id = 0
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '../..', 'assets')
|
||||
script_dir, '..', 'assets')
|
||||
layouts = {
|
||||
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
|
||||
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
|
||||
|
@ -89,17 +92,20 @@ class Level:
|
|||
elif col == '400':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank')
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
player_id += 1
|
||||
|
||||
elif col == '401':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior')
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
player_id += 1
|
||||
|
||||
elif col == '402':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage')
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
player_id += 1
|
||||
|
||||
else:
|
||||
# Monster Generation
|
||||
|
@ -167,6 +173,14 @@ class Level:
|
|||
|
||||
debug('v0.6')
|
||||
|
||||
for player in self.player_sprites:
|
||||
if player.is_dead():
|
||||
print(player.stats.health)
|
||||
player.kill()
|
||||
|
||||
if self.player_sprites == []:
|
||||
self.__init__()
|
||||
|
||||
if not self.game_paused:
|
||||
# Update the game
|
||||
for player in self.player_sprites:
|
||||
|
@ -177,11 +191,5 @@ class Level:
|
|||
self.apply_damage_to_player()
|
||||
self.visible_sprites.update()
|
||||
|
||||
# self.visible_sprites.enemy_update(self.player)
|
||||
# self.player_attack_logic()
|
||||
else:
|
||||
debug('PAUSED')
|
||||
|
||||
for player in self.player_sprites:
|
||||
if player.stats.health <= 0:
|
||||
player.kill()
|
|
@ -19,10 +19,11 @@ class Game:
|
|||
pygame.display.set_caption('Pneuma')
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
self.level = Level()
|
||||
self.level = Level(self.extract_features,
|
||||
self.convert_features_to_tensor)
|
||||
|
||||
# Sound
|
||||
main_sound = pygame.mixer.Sound('../assets/audio/main.ogg')
|
||||
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
|
||||
main_sound.set_volume(0.4)
|
||||
main_sound.play(loops=-1)
|
||||
|
||||
|
@ -35,14 +36,17 @@ class Game:
|
|||
for i, player in enumerate(self.level.player_sprites):
|
||||
|
||||
player_action_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_action": player._input.action
|
||||
}
|
||||
|
||||
player_reward_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_exp": player.stats.exp
|
||||
}
|
||||
|
||||
player_state_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_position": player.rect.center,
|
||||
"player role": player.stats.role_id,
|
||||
"player_health": player.stats.health,
|
||||
|
@ -80,10 +84,6 @@ class Game:
|
|||
|
||||
def convert_features_to_tensor(self):
|
||||
|
||||
self.state_tensors = []
|
||||
self.action_tensors = []
|
||||
self.reward_tensors = []
|
||||
|
||||
for features in self.state_features:
|
||||
info_array = []
|
||||
|
||||
|
@ -125,8 +125,9 @@ class Game:
|
|||
state_tensor = torch.tensor(
|
||||
np.array(info_array, dtype=np.float32))
|
||||
|
||||
self.state_tensors.append(state_tensor)
|
||||
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.state_tensor = state_tensor
|
||||
|
||||
for features in self.action_features:
|
||||
info_array = []
|
||||
|
@ -139,7 +140,9 @@ class Game:
|
|||
action_tensor = torch.tensor(
|
||||
np.array(action_info, dtype=np.float32))
|
||||
|
||||
self.action_tensors.append(action_tensor)
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.action_tensor = action_tensor
|
||||
|
||||
for features in self.reward_features:
|
||||
info_array = []
|
||||
|
@ -152,7 +155,9 @@ class Game:
|
|||
reward_tensor = torch.tensor(
|
||||
np.array(reward_info, dtype=np.float32))
|
||||
|
||||
self.reward_tensors.append(reward_tensor)
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.reward_tensor = reward_tensor
|
||||
|
||||
def run(self):
|
||||
|
||||
|
@ -166,10 +171,8 @@ class Game:
|
|||
|
||||
self.screen.fill(WATER_COLOR)
|
||||
|
||||
self.extract_features()
|
||||
self.convert_features_to_tensor()
|
||||
self.level.run(who='observer')
|
||||
|
||||
self.level.run('observer')
|
||||
pygame.display.update()
|
||||
self.clock.tick(FPS)
|
||||
|
||||
|
@ -179,9 +182,4 @@ if __name__ == '__main__':
|
|||
game = Game()
|
||||
for i in range(0, 10000):
|
||||
game.run()
|
||||
game.extract_features()
|
||||
game.convert_features_to_tensor()
|
||||
if i == 100:
|
||||
print(game.reward_tensors)
|
||||
print(game.action_tensors)
|
||||
print(game.state_tensors)
|
||||
print(i)
|
|
@ -2,6 +2,7 @@ import pygame
|
|||
from csv import reader
|
||||
from os import walk
|
||||
|
||||
|
||||
def import_csv_layout(path):
|
||||
terrain_map = []
|
||||
with open(path) as level_map:
|
||||
|
@ -10,6 +11,7 @@ def import_csv_layout(path):
|
|||
terrain_map.append(list(row))
|
||||
return terrain_map
|
||||
|
||||
|
||||
def import_folder(path):
|
||||
surface_list = []
|
||||
|
Loading…
Reference in a new issue