Hopefully implemented PPO
This commit is contained in:
parent
b4a6e99fce
commit
115b2e4151
13 changed files with 211 additions and 185 deletions
0
agents/ppo/__init__.py
Normal file
0
agents/ppo/__init__.py
Normal file
|
@ -6,7 +6,9 @@ from .brain import ActorNetwork, CriticNetwork, PPOMemory
|
|||
|
||||
class Agent:
|
||||
|
||||
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, policy_clip=0.2, batch_size=64, N=2048, n_epochs=10, gae_lambda=0.95):
|
||||
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
|
||||
policy_clip=0.2, batch_size=64, N=2048, n_epochs=10,
|
||||
gae_lambda=0.95):
|
||||
|
||||
self.gamma = gamma
|
||||
self.policy_clip = policy_clip
|
||||
|
@ -37,7 +39,7 @@ class Agent:
|
|||
print('.. done ...')
|
||||
|
||||
def choose_action(self, observation):
|
||||
state = observation.to(self.actor.device, dtype=T.float)
|
||||
state = T.tensor([observation], dtype=T.float).to(self.actor.device)
|
||||
|
||||
dist = self.actor(state)
|
||||
value = self.critic(state)
|
||||
|
@ -46,12 +48,11 @@ class Agent:
|
|||
probs = T.squeeze(dist.log_prob(action)).item()
|
||||
action = T.squeeze(action).item()
|
||||
value = T.squeeze(value).item()
|
||||
|
||||
return action, probs, value
|
||||
|
||||
def learn(self):
|
||||
for _ in range(self.n_epochs):
|
||||
state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, done_arr, batches = self.memory.generate_batches()
|
||||
state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
|
||||
|
||||
values = vals_arr
|
||||
advantage = np.zeros(len(reward_arr), dtype=np.float32)
|
||||
|
@ -62,12 +63,13 @@ class Agent:
|
|||
for k in range(t, len(reward_arr)-1):
|
||||
a_t += discount * \
|
||||
(reward_arr[k] + self.gamma*values[k+1]
|
||||
* (1-int(done_arr[k])) - values[k])
|
||||
* (1-int(dones_arr[k])) - values[k])
|
||||
discount *= self.gamma * self.gae_lambda
|
||||
advantage[t] = a_t
|
||||
advantage = T.tensor(advantage).to(self.actor.device)
|
||||
|
||||
values = T.tensor(values).to(self.actor.device)
|
||||
|
||||
for batch in batches:
|
||||
states = T.tensor(state_arr[batch], dtype=T.float).to(
|
||||
self.actor.device)
|
|
@ -35,6 +35,7 @@ class PPOMemory:
|
|||
|
||||
def store_memory(self, state, action, probs, vals, reward, done):
|
||||
self.states.append(state)
|
||||
self.actions.append(action)
|
||||
self.probs.append(probs)
|
||||
self.vals.append(vals)
|
||||
self.rewards.append(reward)
|
||||
|
@ -55,6 +56,7 @@ class ActorNetwork(nn.Module):
|
|||
super(ActorNetwork, self).__init__()
|
||||
|
||||
self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
|
||||
|
||||
self.actor = nn.Sequential(
|
||||
nn.Linear(input_dim, fc1_dims),
|
||||
nn.ReLU(),
|
||||
|
@ -90,6 +92,7 @@ class CriticNetwork(nn.Module):
|
|||
super(CriticNetwork, self).__init__()
|
||||
|
||||
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
|
||||
|
||||
self.critic = nn.Sequential(
|
||||
nn.Linear(input_dims, fc1_dims),
|
||||
nn.ReLU(),
|
|
@ -18,9 +18,9 @@
|
|||
-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,395,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
|
|
|
|
@ -1,7 +1,7 @@
|
|||
# game setup
|
||||
WIDTH = 1280
|
||||
HEIGHT = 720
|
||||
FPS = 500
|
||||
FPS = 1000
|
||||
TILESIZE = 64
|
||||
HITBOX_OFFSET = {
|
||||
'player': (-6, -26),
|
||||
|
|
|
@ -36,7 +36,7 @@ class InputHandler:
|
|||
self.magic_swap_time = None
|
||||
|
||||
# Setup Action Space
|
||||
self.num_actions = 7
|
||||
self.possible_actions = [0, 1, 2, 3, 4, 5]
|
||||
self.action = 10
|
||||
|
||||
def check_input(self, button, speed, hitbox, obstacle_sprites, rect, player):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pygame
|
||||
import numpy as np
|
||||
from random import randint
|
||||
|
||||
from configs.game.weapon_config import weapon_data
|
||||
|
@ -10,12 +11,12 @@ from .components.animaton import AnimationHandler
|
|||
|
||||
from effects.particle_effects import AnimationPlayer
|
||||
|
||||
from agent.agent import Agent
|
||||
from agents.ppo.agent import Agent
|
||||
|
||||
|
||||
class Player(pygame.sprite.Sprite):
|
||||
|
||||
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role, player_id, extract_features, convert_features_to_tensor):
|
||||
def __init__(self, position, groups, obstacle_sprites, visible_sprites, attack_sprites, attackable_sprites, role, player_id):
|
||||
super().__init__(groups)
|
||||
|
||||
# Setup Sprites
|
||||
|
@ -44,14 +45,6 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.distance_direction_from_enemy = None
|
||||
|
||||
# Setup AI
|
||||
self.extract_features = extract_features
|
||||
self.convert_features_to_tensor = convert_features_to_tensor
|
||||
self.agent = Agent(input_dims=398, n_actions=self._input.num_actions)
|
||||
self.state_tensor = None
|
||||
self.action_tensor = None
|
||||
self.reward_tensor = None
|
||||
|
||||
def get_status(self):
|
||||
if self._input.movement.direction.x == 0 and self._input.movement.direction.y == 0:
|
||||
if 'idle' not in self.status and 'attack' not in self.status:
|
||||
|
@ -98,7 +91,54 @@ class Player(pygame.sprite.Sprite):
|
|||
return (base_damage + spell_damage)
|
||||
|
||||
def get_current_state(self):
|
||||
pass
|
||||
|
||||
self.action_features = [self._input.action]
|
||||
self.reward_features = [self.stats.exp]
|
||||
self.state_features = [
|
||||
self.stats.role_id,
|
||||
self.rect.center[0],
|
||||
self.rect.center[1],
|
||||
self.stats.health,
|
||||
self.stats.energy,
|
||||
self.stats.attack,
|
||||
self.stats.magic,
|
||||
self.stats.speed,
|
||||
int(self._input.combat.vulnerable),
|
||||
int(self._input.can_move),
|
||||
int(self._input.attacking),
|
||||
int(self._input.can_rotate_weapon),
|
||||
int(self._input.can_swap_magic)
|
||||
]
|
||||
|
||||
enemy_states = []
|
||||
|
||||
for distance, direction, enemy in self.distance_direction_from_enemy:
|
||||
enemy_states.extend([
|
||||
distance,
|
||||
direction[0],
|
||||
direction[1],
|
||||
enemy.stats.monster_id,
|
||||
0 if enemy.animation.status == "idle" else (
|
||||
1 if enemy.animation.status == "move" else 2),
|
||||
enemy.stats.health,
|
||||
enemy.stats.attack,
|
||||
enemy.stats.speed,
|
||||
enemy.stats.exp,
|
||||
enemy.stats.attack_radius,
|
||||
enemy.stats.notice_radius
|
||||
])
|
||||
self.state_features.extend(enemy_states)
|
||||
|
||||
def setup_agent(self):
|
||||
# Setup AI
|
||||
self.get_current_state()
|
||||
self.agent = Agent(
|
||||
input_dims=len(self.state_features), n_actions=len(self._input.possible_actions), batch_size=5, n_epochs=4)
|
||||
self.score = 0
|
||||
self.learn_iters = 0
|
||||
|
||||
self.n_steps = 0
|
||||
self.N = 20
|
||||
|
||||
def is_dead(self):
|
||||
if self.stats.health == 0:
|
||||
|
@ -108,28 +148,32 @@ class Player(pygame.sprite.Sprite):
|
|||
return False
|
||||
|
||||
def update(self):
|
||||
self.extract_features()
|
||||
self.convert_features_to_tensor()
|
||||
|
||||
# Get the current state
|
||||
self.get_current_state()
|
||||
|
||||
# Choose action based on current state
|
||||
action, probs, value = self.agent.choose_action(self.state_tensor)
|
||||
action, probs, value = self.agent.choose_action(self.state_features)
|
||||
|
||||
print(action)
|
||||
self.n_steps += 1
|
||||
# Apply chosen action
|
||||
self._input.check_input(action, self.stats.speed, self.animation.hitbox,
|
||||
self.obstacle_sprites, self.animation.rect, self)
|
||||
|
||||
done = self.is_dead()
|
||||
self.done = self.is_dead()
|
||||
|
||||
self.extract_features()
|
||||
self.convert_features_to_tensor()
|
||||
self.score = self.stats.exp
|
||||
self.agent.remember(self.state_features, action,
|
||||
probs, value, self.stats.exp, self.done)
|
||||
|
||||
self.agent.remember(self.state_tensor, self.action_tensor,
|
||||
probs, value, self.reward_tensor, done)
|
||||
|
||||
if done:
|
||||
if self.n_steps % self.N == 0:
|
||||
self.agent.learn()
|
||||
self.learn_iters += 1
|
||||
|
||||
self.get_current_state()
|
||||
|
||||
if self.done:
|
||||
self.agent.learn()
|
||||
self.agent.memory.clear_memory()
|
||||
|
||||
# Refresh objects based on input
|
||||
self.status = self._input.status
|
||||
|
|
|
@ -20,15 +20,11 @@ from .camera import Camera
|
|||
|
||||
class Level:
|
||||
|
||||
def __init__(self, extract_features,
|
||||
convert_features_to_tensor):
|
||||
def __init__(self):
|
||||
|
||||
# General Settings
|
||||
self.game_paused = False
|
||||
|
||||
# AI setup
|
||||
self.extract_features = extract_features
|
||||
self.convert_features_to_tensor = convert_features_to_tensor
|
||||
self.done = False
|
||||
|
||||
# Get display surface
|
||||
self.display_surface = pygame.display.get_surface()
|
||||
|
@ -43,6 +39,8 @@ class Level:
|
|||
self.create_map()
|
||||
self.get_players_enemies()
|
||||
self.get_distance_direction()
|
||||
for player in self.player_sprites:
|
||||
player.setup_agent()
|
||||
|
||||
# UI setup
|
||||
self.ui = UI()
|
||||
|
@ -85,26 +83,99 @@ class Level:
|
|||
|
||||
if style == 'entities':
|
||||
# The numbers represent their IDs in .csv files generated from TILED.
|
||||
if col == '395':
|
||||
if col == '500':
|
||||
self.observer = Observer(
|
||||
(x, y), [self.visible_sprites])
|
||||
|
||||
elif col == '400':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '401':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '402':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id, self.extract_features, self.convert_features_to_tensor)
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id)
|
||||
player_id += 1
|
||||
|
||||
else:
|
||||
# Monster Generation
|
||||
if col == '390':
|
||||
monster_name = 'bamboo'
|
||||
elif col == '391':
|
||||
monster_name = 'spirit'
|
||||
elif col == '392':
|
||||
monster_name = 'raccoon'
|
||||
else:
|
||||
monster_name = 'squid'
|
||||
|
||||
Enemy(monster_name, (x, y), [
|
||||
self.visible_sprites, self.attackable_sprites], self.visible_sprites, self.obstacle_sprites)
|
||||
|
||||
def reset_map(self):
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
asset_path = os.path.join(
|
||||
script_dir, '..', 'assets')
|
||||
layouts = {
|
||||
'boundary': import_csv_layout(f"{asset_path}/map/FloorBlocks.csv"),
|
||||
'grass': import_csv_layout(f"{asset_path}/map/Grass.csv"),
|
||||
'objects': import_csv_layout(f"{asset_path}/map/Objects.csv"),
|
||||
'entities': import_csv_layout(f"{asset_path}/map/Entities.csv")
|
||||
}
|
||||
|
||||
graphics = {
|
||||
'grass': import_folder(f"{asset_path}/graphics/grass"),
|
||||
'objects': import_folder(f"{asset_path}/graphics/objects")
|
||||
}
|
||||
|
||||
for style, layout in layouts.items():
|
||||
for row_index, row in enumerate(layout):
|
||||
for col_index, col in enumerate(row):
|
||||
if col != '-1':
|
||||
x = col_index * TILESIZE
|
||||
y = row_index * TILESIZE
|
||||
if style == 'boundary':
|
||||
Tile((x, y), [self.obstacle_sprites], 'invisible')
|
||||
|
||||
if style == 'grass':
|
||||
random_grass_image = choice(graphics['grass'])
|
||||
Tile((x, y), [self.visible_sprites, self.obstacle_sprites,
|
||||
self.attackable_sprites], 'grass', random_grass_image)
|
||||
|
||||
if style == 'objects':
|
||||
surf = graphics['objects'][int(col)]
|
||||
Tile((x, y), [self.visible_sprites,
|
||||
self.obstacle_sprites], 'object', surf)
|
||||
|
||||
if style == 'entities':
|
||||
# The numbers represent their IDs in .csv files generated from TILED.
|
||||
if col == '500':
|
||||
self.observer = Observer(
|
||||
(x, y), [self.visible_sprites])
|
||||
|
||||
elif col == '400':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'tank', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '401':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'warrior', player_id)
|
||||
player_id += 1
|
||||
|
||||
elif col == '402':
|
||||
# Player Generation
|
||||
Player(
|
||||
(x, y), [self.visible_sprites], self.obstacle_sprites, self.visible_sprites, self.attack_sprites, self.attackable_sprites, 'mage', player_id)
|
||||
player_id += 1
|
||||
|
||||
else:
|
||||
|
@ -175,11 +246,7 @@ class Level:
|
|||
|
||||
for player in self.player_sprites:
|
||||
if player.is_dead():
|
||||
print(player.stats.health)
|
||||
player.kill()
|
||||
|
||||
if self.player_sprites == []:
|
||||
self.__init__()
|
||||
self.done = True
|
||||
|
||||
if not self.game_paused:
|
||||
# Update the game
|
||||
|
|
184
main.py
184
main.py
|
@ -2,7 +2,7 @@ import sys
|
|||
import numpy as np
|
||||
import torch
|
||||
import pygame
|
||||
|
||||
from tqdm import tqdm
|
||||
from configs.system.window_config import WIDTH, HEIGHT, WATER_COLOR, FPS
|
||||
|
||||
from level.level import Level
|
||||
|
@ -19,146 +19,13 @@ class Game:
|
|||
pygame.display.set_caption('Pneuma')
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
self.level = Level(self.extract_features,
|
||||
self.convert_features_to_tensor)
|
||||
self.level = Level()
|
||||
|
||||
# Sound
|
||||
main_sound = pygame.mixer.Sound('assets/audio/main.ogg')
|
||||
main_sound.set_volume(0.4)
|
||||
main_sound.play(loops=-1)
|
||||
|
||||
def extract_features(self):
|
||||
self.state_features = []
|
||||
self.reward_features = []
|
||||
self.action_features = []
|
||||
self.features = []
|
||||
|
||||
for i, player in enumerate(self.level.player_sprites):
|
||||
|
||||
player_action_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_action": player._input.action
|
||||
}
|
||||
|
||||
player_reward_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_exp": player.stats.exp
|
||||
}
|
||||
|
||||
player_state_features = {
|
||||
"player_id": player.player_id,
|
||||
"player_position": player.rect.center,
|
||||
"player role": player.stats.role_id,
|
||||
"player_health": player.stats.health,
|
||||
"player_energy": player.stats.energy,
|
||||
"player_attack": player.stats.attack,
|
||||
"player_magic": player.stats.magic,
|
||||
"player_speed": player.stats.speed,
|
||||
"player_vulnerable": int(player._input.combat.vulnerable),
|
||||
"player_can_move": int(player._input.can_move),
|
||||
"player_attacking": int(player._input.attacking),
|
||||
"player_can_rotate_weapon": int(player._input.can_rotate_weapon),
|
||||
"playercan_swap_magic": int(player._input.can_swap_magic)
|
||||
}
|
||||
|
||||
distances_directions = []
|
||||
|
||||
for distance, direction, enemy in player.distance_direction_from_enemy:
|
||||
distances_directions.append({
|
||||
"enemy_id": enemy.stats.monster_id,
|
||||
"enemy_status": 0 if enemy.animation.status == "idle" else (1 if enemy.animation.status == "move" else 2),
|
||||
"enemy_health": enemy.stats.health,
|
||||
"enemy_attack": enemy.stats.attack,
|
||||
"enemy_speed": enemy.stats.speed,
|
||||
"enemy_attack_radius": enemy.stats.attack_radius,
|
||||
"enemy_notice_radius": enemy.stats.notice_radius,
|
||||
"enemy_exp": enemy.stats.exp,
|
||||
"enemy_distance": distance,
|
||||
"enemy_direction": direction
|
||||
})
|
||||
|
||||
player_state_features["enemies"] = distances_directions
|
||||
self.reward_features.append(player_reward_features)
|
||||
self.state_features.append(player_state_features)
|
||||
self.action_features.append(player_action_features)
|
||||
|
||||
def convert_features_to_tensor(self):
|
||||
|
||||
for features in self.state_features:
|
||||
info_array = []
|
||||
|
||||
# Adding player features to tensor
|
||||
player_info = [
|
||||
features['player_position'][0],
|
||||
features['player_position'][1],
|
||||
features['player role'],
|
||||
features['player_health'],
|
||||
features['player_energy'],
|
||||
features['player_attack'],
|
||||
features['player_magic'],
|
||||
features['player_speed'],
|
||||
features['player_vulnerable'],
|
||||
features['player_can_move'],
|
||||
features['player_attacking'],
|
||||
features['player_can_rotate_weapon'],
|
||||
features['playercan_swap_magic'],
|
||||
]
|
||||
info_array.extend(player_info)
|
||||
|
||||
# Adding enemy features per player
|
||||
for enemy in features['enemies']:
|
||||
enemy_info = [
|
||||
enemy['enemy_id'],
|
||||
enemy['enemy_status'],
|
||||
enemy['enemy_health'],
|
||||
enemy['enemy_attack'],
|
||||
enemy['enemy_speed'],
|
||||
enemy['enemy_attack_radius'],
|
||||
enemy['enemy_notice_radius'],
|
||||
enemy['enemy_exp'],
|
||||
enemy['enemy_distance'],
|
||||
enemy['enemy_direction'][0],
|
||||
enemy['enemy_direction'][1]
|
||||
]
|
||||
info_array.extend(enemy_info)
|
||||
|
||||
state_tensor = torch.tensor(
|
||||
np.array(info_array, dtype=np.float32))
|
||||
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.state_tensor = state_tensor
|
||||
|
||||
for features in self.action_features:
|
||||
info_array = []
|
||||
|
||||
# Adding action features
|
||||
action_info = [
|
||||
features["player_action"]
|
||||
]
|
||||
|
||||
action_tensor = torch.tensor(
|
||||
np.array(action_info, dtype=np.float32))
|
||||
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.action_tensor = action_tensor
|
||||
|
||||
for features in self.reward_features:
|
||||
info_array = []
|
||||
|
||||
# Adding reward features
|
||||
reward_info = [
|
||||
features["player_exp"]
|
||||
]
|
||||
|
||||
reward_tensor = torch.tensor(
|
||||
np.array(reward_info, dtype=np.float32))
|
||||
|
||||
for player in self.level.player_sprites:
|
||||
if player.player_id == features["player_id"]:
|
||||
player.reward_tensor = reward_tensor
|
||||
|
||||
def run(self):
|
||||
|
||||
for event in pygame.event.get():
|
||||
|
@ -178,8 +45,51 @@ class Game:
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
n_games = 300
|
||||
|
||||
figure_file = 'plots/score.png'
|
||||
score_history = []
|
||||
best_score = 0
|
||||
avg_score = 0
|
||||
|
||||
agent_list = []
|
||||
|
||||
game_len = 10000
|
||||
|
||||
game = Game()
|
||||
for i in range(0, 10000):
|
||||
|
||||
for i in tqdm(range(n_games)):
|
||||
# TODO: Make game.level.reset_map() so we don't __init__ everything all the time (such a waste)
|
||||
game.level.__init__()
|
||||
# TODO: Make game.level.reset_map() so we don't pull out and load the agent every time (There is -definitevly- a better way)
|
||||
for player in game.level.player_sprites:
|
||||
for player_id, agent in agent_list:
|
||||
if player.player_id == player_id:
|
||||
player.agent = agent
|
||||
agent_list = []
|
||||
done = False
|
||||
score = 0
|
||||
for _ in range(game_len):
|
||||
if not game.level.done:
|
||||
game.run()
|
||||
print(i)
|
||||
else:
|
||||
break
|
||||
for player in game.level.player_sprites:
|
||||
agent_list.append((player.player_id, player.agent))
|
||||
|
||||
if i == n_games-1 and game.level.enemy_sprites != []:
|
||||
for player in game.level.player_sprites:
|
||||
player.stats.exp -= 5
|
||||
player.update()
|
||||
|
||||
for player in game.level.player_sprites:
|
||||
player.agent.save_models()
|
||||
|
||||
# TODO: Make it so that scores appear here for each player
|
||||
# score_history.append(game.level.player.score)
|
||||
# print(score)
|
||||
# avg_score = np.mean(score_history[-100:])
|
||||
|
||||
# if avg_score > best_score:
|
||||
# best_score = avg_score
|
||||
# game.level.player.agent.save_models()
|
||||
|
|
BIN
tmp/ppo/actor_torch_ppo
Normal file
BIN
tmp/ppo/actor_torch_ppo
Normal file
Binary file not shown.
BIN
tmp/ppo/critic_torch_ppo
Normal file
BIN
tmp/ppo/critic_torch_ppo
Normal file
Binary file not shown.
Loading…
Reference in a new issue