Updated adam to reduce plasticity and made the agent more observant, amongst other things
This commit is contained in:
parent
7eb7228a8c
commit
78c298e072
5 changed files with 46 additions and 41 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -164,3 +164,5 @@ cython_debug/
|
|||
|
||||
# Random stuff
|
||||
__pycache__/
|
||||
chkpts/
|
||||
figures/
|
|
@ -86,8 +86,8 @@ class Player(pygame.sprite.Sprite):
|
|||
print("Attempting to load models ...")
|
||||
try:
|
||||
self.agent.load_models(
|
||||
actr_chkpt=f"run{load}/A{self.player_id}",
|
||||
crtc_chkpt=f"run{load}/C{self.player_id}"
|
||||
actr_chkpt=f"{chkpt_dir}/../run{load}/A{self.player_id}",
|
||||
crtc_chkpt=f"{chkpt_dir}/../run{load}/C{self.player_id}"
|
||||
)
|
||||
print("Models loaded ...\n")
|
||||
|
||||
|
@ -160,9 +160,9 @@ class Player(pygame.sprite.Sprite):
|
|||
if hasattr(self, 'state_features'):
|
||||
self.old_state_features = self.state_features
|
||||
|
||||
self.reward = self.stats.exp\
|
||||
+ self.stats.health/self.stats.stats['health'] - 1\
|
||||
- nearest_dist/np.sqrt(np.sum(self.map_edge))
|
||||
self.reward = self.stats.exp
|
||||
# + self.stats.health/self.stats.stats['health'] - 1\
|
||||
# - nearest_dist/np.sqrt(np.sum(self.map_edge))
|
||||
|
||||
self.state_features = [
|
||||
self.animation.rect.center[0]/self.map_edge[0],
|
||||
|
@ -173,27 +173,27 @@ class Player(pygame.sprite.Sprite):
|
|||
self.stats.energy/self.stats.stats['energy'],
|
||||
]
|
||||
|
||||
self.state_features.extend([
|
||||
nearest_dist/np.sqrt(np.sum(self.map_edge)),
|
||||
nearest_en_dir[0],
|
||||
nearest_en_dir[1],
|
||||
nearest_enemy.stats.exp
|
||||
])
|
||||
# self.state_features.extend([
|
||||
# nearest_dist/np.sqrt(np.sum(self.map_edge)),
|
||||
# nearest_en_dir[0],
|
||||
# nearest_en_dir[1],
|
||||
# nearest_enemy.stats.exp
|
||||
# ])
|
||||
|
||||
# for distance, direction, enemy in self.distance_direction_from_enemy:
|
||||
# self.state_features.extend([
|
||||
#
|
||||
# distance/np.sqrt(np.sum(self.map_edge)),
|
||||
#
|
||||
# direction[0],
|
||||
#
|
||||
# direction[1],
|
||||
#
|
||||
# enemy.stats.health /
|
||||
# enemy.stats.monster_info['health'],
|
||||
#
|
||||
# enemy.stats.exp,
|
||||
# ])
|
||||
for distance, direction, enemy in sorted_distances[:5]:
|
||||
self.state_features.extend([
|
||||
|
||||
distance/np.sqrt(np.sum(self.map_edge)),
|
||||
|
||||
direction[0],
|
||||
|
||||
direction[1],
|
||||
|
||||
enemy.stats.health /
|
||||
enemy.stats.monster_info['health'],
|
||||
|
||||
enemy.stats.exp,
|
||||
])
|
||||
|
||||
if hasattr(self, 'num_features'):
|
||||
while len(self.state_features) < self.num_features:
|
||||
|
|
BIN
figures/.DS_Store
vendored
BIN
figures/.DS_Store
vendored
Binary file not shown.
|
@ -52,7 +52,7 @@ class PPOMemory:
|
|||
|
||||
class ActorNetwork(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'):
|
||||
def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp'):
|
||||
super(ActorNetwork, self).__init__()
|
||||
|
||||
self.chkpt_dir = chkpt_dir
|
||||
|
@ -68,7 +68,7 @@ class ActorNetwork(nn.Module):
|
|||
nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
|
||||
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||
|
||||
|
@ -80,18 +80,19 @@ class ActorNetwork(nn.Module):
|
|||
|
||||
return dist
|
||||
|
||||
def save_checkpoint(self, filename='actor_torch_ppo'):
|
||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||
def save_checkpoint(self, filename):
|
||||
T.save(self.state_dict(), os.path.join(filename))
|
||||
|
||||
def load_checkpoint(self, filename='actor_torch_ppo'):
|
||||
def load_checkpoint(self, filename):
|
||||
print(filename)
|
||||
self.load_state_dict(
|
||||
T.load(os.path.join(self.chkpt_dir, filename),
|
||||
T.load(os.path.join(filename),
|
||||
map_location=self.device))
|
||||
|
||||
|
||||
class CriticNetwork(nn.Module):
|
||||
|
||||
def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'):
|
||||
def __init__(self, input_dims, alpha, fc1_dims=2048, fc2_dims=2048, chkpt_dir='tmp'):
|
||||
super(CriticNetwork, self).__init__()
|
||||
|
||||
self.chkpt_dir = chkpt_dir
|
||||
|
@ -105,16 +106,16 @@ class CriticNetwork(nn.Module):
|
|||
nn.LeakyReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.LeakyReLU(),
|
||||
# nn.Linear(fc1_dims, fc2_dims),
|
||||
# nn.LeakyReLU(),
|
||||
# nn.Linear(fc1_dims, fc2_dims),
|
||||
# nn.LeakyReLU(),
|
||||
nn.Linear(fc1_dims, fc2_dims),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(fc2_dims, 1)
|
||||
)
|
||||
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
|
||||
|
||||
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
|
||||
|
||||
|
@ -124,10 +125,11 @@ class CriticNetwork(nn.Module):
|
|||
value = self.critic(state)
|
||||
return value
|
||||
|
||||
def save_checkpoint(self, filename='critic_torch_ppo'):
|
||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||
def save_checkpoint(self, filename):
|
||||
T.save(self.state_dict(), os.path.join(filename))
|
||||
|
||||
def load_checkpoint(self, filename='critic_torch_ppo'):
|
||||
def load_checkpoint(self, filename):
|
||||
print(filename)
|
||||
self.load_state_dict(
|
||||
T.load(os.path.join(self.chkpt_dir, filename),
|
||||
T.load(os.path.join(filename),
|
||||
map_location=self.device))
|
||||
|
|
|
@ -25,6 +25,7 @@ def plot_avg_time(time_steps, num_players, fig_path):
|
|||
for player in time_steps:
|
||||
plt.plot(player)
|
||||
plt.savefig(os.path.join(fig_path, 'avg_time.png'))
|
||||
plt.close()
|
||||
|
||||
def plot_score(scores, num_players, figure_path):
|
||||
|
||||
|
|
Loading…
Reference in a new issue