Updated the Adam optimizer to reduce plasticity and made the agent more observant, amongst other things
This commit is contained in:
parent
7eb7228a8c
commit
78c298e072
5 changed files with 46 additions and 41 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -164,3 +164,5 @@ cython_debug/
|
||||||
|
|
||||||
# Random stuff
|
# Random stuff
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
chkpts/
|
||||||
|
figures/
|
|
@ -86,8 +86,8 @@ class Player(pygame.sprite.Sprite):
|
||||||
print("Attempting to load models ...")
|
print("Attempting to load models ...")
|
||||||
try:
|
try:
|
||||||
self.agent.load_models(
|
self.agent.load_models(
|
||||||
actr_chkpt=f"run{load}/A{self.player_id}",
|
actr_chkpt=f"{chkpt_dir}/../run{load}/A{self.player_id}",
|
||||||
crtc_chkpt=f"run{load}/C{self.player_id}"
|
crtc_chkpt=f"{chkpt_dir}/../run{load}/C{self.player_id}"
|
||||||
)
|
)
|
||||||
print("Models loaded ...\n")
|
print("Models loaded ...\n")
|
||||||
|
|
||||||
|
@ -160,9 +160,9 @@ class Player(pygame.sprite.Sprite):
|
||||||
if hasattr(self, 'state_features'):
|
if hasattr(self, 'state_features'):
|
||||||
self.old_state_features = self.state_features
|
self.old_state_features = self.state_features
|
||||||
|
|
||||||
self.reward = self.stats.exp\
|
self.reward = self.stats.exp
|
||||||
+ self.stats.health/self.stats.stats['health'] - 1\
|
# + self.stats.health/self.stats.stats['health'] - 1\
|
||||||
- nearest_dist/np.sqrt(np.sum(self.map_edge))
|
# - nearest_dist/np.sqrt(np.sum(self.map_edge))
|
||||||
|
|
||||||
self.state_features = [
|
self.state_features = [
|
||||||
self.animation.rect.center[0]/self.map_edge[0],
|
self.animation.rect.center[0]/self.map_edge[0],
|
||||||
|
@ -173,28 +173,28 @@ class Player(pygame.sprite.Sprite):
|
||||||
self.stats.energy/self.stats.stats['energy'],
|
self.stats.energy/self.stats.stats['energy'],
|
||||||
]
|
]
|
||||||
|
|
||||||
self.state_features.extend([
|
|
||||||
nearest_dist/np.sqrt(np.sum(self.map_edge)),
|
|
||||||
nearest_en_dir[0],
|
|
||||||
nearest_en_dir[1],
|
|
||||||
nearest_enemy.stats.exp
|
|
||||||
])
|
|
||||||
|
|
||||||
# for distance, direction, enemy in self.distance_direction_from_enemy:
|
|
||||||
# self.state_features.extend([
|
# self.state_features.extend([
|
||||||
#
|
# nearest_dist/np.sqrt(np.sum(self.map_edge)),
|
||||||
# distance/np.sqrt(np.sum(self.map_edge)),
|
# nearest_en_dir[0],
|
||||||
#
|
# nearest_en_dir[1],
|
||||||
# direction[0],
|
# nearest_enemy.stats.exp
|
||||||
#
|
|
||||||
# direction[1],
|
|
||||||
#
|
|
||||||
# enemy.stats.health /
|
|
||||||
# enemy.stats.monster_info['health'],
|
|
||||||
#
|
|
||||||
# enemy.stats.exp,
|
|
||||||
# ])
|
# ])
|
||||||
|
|
||||||
|
for distance, direction, enemy in sorted_distances[:5]:
|
||||||
|
self.state_features.extend([
|
||||||
|
|
||||||
|
distance/np.sqrt(np.sum(self.map_edge)),
|
||||||
|
|
||||||
|
direction[0],
|
||||||
|
|
||||||
|
direction[1],
|
||||||
|
|
||||||
|
enemy.stats.health /
|
||||||
|
enemy.stats.monster_info['health'],
|
||||||
|
|
||||||
|
enemy.stats.exp,
|
||||||
|
])
|
||||||
|
|
||||||
if hasattr(self, 'num_features'):
|
if hasattr(self, 'num_features'):
|
||||||
while len(self.state_features) < self.num_features:
|
while len(self.state_features) < self.num_features:
|
||||||
self.state_features.append(0)
|
self.state_features.append(0)
|
||||||
|
|
BIN
figures/.DS_Store
vendored
BIN
figures/.DS_Store
vendored
Binary file not shown.
|
@ -52,7 +52,7 @@ class PPOMemory:
|
||||||
|
|
||||||
class ActorNetwork(nn.Module):
|
class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp'):
|
||||||
super(ActorNetwork, self).__init__()
|
super(ActorNetwork, self).__init__()
|
||||||
|
|
||||||
self.chkpt_dir = chkpt_dir
|
self.chkpt_dir = chkpt_dir
|
||||||
|
@ -68,7 +68,7 @@ class ActorNetwork(nn.Module):
|
||||||
nn.Softmax(dim=-1)
|
nn.Softmax(dim=-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
|
||||||
|
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
@ -80,18 +80,19 @@ class ActorNetwork(nn.Module):
|
||||||
|
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
def save_checkpoint(self, filename='actor_torch_ppo'):
|
def save_checkpoint(self, filename):
|
||||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
T.save(self.state_dict(), os.path.join(filename))
|
||||||
|
|
||||||
def load_checkpoint(self, filename='actor_torch_ppo'):
|
def load_checkpoint(self, filename):
|
||||||
|
print(filename)
|
||||||
self.load_state_dict(
|
self.load_state_dict(
|
||||||
T.load(os.path.join(self.chkpt_dir, filename),
|
T.load(os.path.join(filename),
|
||||||
map_location=self.device))
|
map_location=self.device))
|
||||||
|
|
||||||
|
|
||||||
class CriticNetwork(nn.Module):
|
class CriticNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'):
|
def __init__(self, input_dims, alpha, fc1_dims=2048, fc2_dims=2048, chkpt_dir='tmp'):
|
||||||
super(CriticNetwork, self).__init__()
|
super(CriticNetwork, self).__init__()
|
||||||
|
|
||||||
self.chkpt_dir = chkpt_dir
|
self.chkpt_dir = chkpt_dir
|
||||||
|
@ -105,16 +106,16 @@ class CriticNetwork(nn.Module):
|
||||||
nn.LeakyReLU(),
|
nn.LeakyReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.LeakyReLU(),
|
nn.LeakyReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
# nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.LeakyReLU(),
|
# nn.LeakyReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
# nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.LeakyReLU(),
|
# nn.LeakyReLU(),
|
||||||
nn.Linear(fc1_dims, fc2_dims),
|
nn.Linear(fc1_dims, fc2_dims),
|
||||||
nn.LeakyReLU(),
|
nn.LeakyReLU(),
|
||||||
nn.Linear(fc2_dims, 1)
|
nn.Linear(fc2_dims, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
|
||||||
|
|
||||||
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
@ -124,10 +125,11 @@ class CriticNetwork(nn.Module):
|
||||||
value = self.critic(state)
|
value = self.critic(state)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def save_checkpoint(self, filename='critic_torch_ppo'):
|
def save_checkpoint(self, filename):
|
||||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
T.save(self.state_dict(), os.path.join(filename))
|
||||||
|
|
||||||
def load_checkpoint(self, filename='critic_torch_ppo'):
|
def load_checkpoint(self, filename):
|
||||||
|
print(filename)
|
||||||
self.load_state_dict(
|
self.load_state_dict(
|
||||||
T.load(os.path.join(self.chkpt_dir, filename),
|
T.load(os.path.join(filename),
|
||||||
map_location=self.device))
|
map_location=self.device))
|
||||||
|
|
|
@ -25,6 +25,7 @@ def plot_avg_time(time_steps, num_players, fig_path):
|
||||||
for player in time_steps:
|
for player in time_steps:
|
||||||
plt.plot(player)
|
plt.plot(player)
|
||||||
plt.savefig(os.path.join(fig_path, 'avg_time.png'))
|
plt.savefig(os.path.join(fig_path, 'avg_time.png'))
|
||||||
|
plt.close()
|
||||||
|
|
||||||
def plot_score(scores, num_players, figure_path):
|
def plot_score(scores, num_players, figure_path):
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue