From 1ab8df01eaf5f4aa9c6b5788ab30d58be65d0e89 Mon Sep 17 00:00:00 2001
From: Vasilis Valatsos
Date: Thu, 7 Mar 2024 10:52:15 +0200
Subject: [PATCH] Converted advantage to float as gae to pass to main.py

---
 main.py         |  2 +-
 ml/ppo/agent.py | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index e474a41..36ee202 100644
--- a/main.py
+++ b/main.py
@@ -133,7 +133,7 @@ def main():
                 = player.agent.entropy
 
             episode_advantage[player.player_id][learn_iters % learnings_per_episode]\
-                = player.agent.advantage
+                = player.agent.gae
 
             learn_iters += 1
 
diff --git a/ml/ppo/agent.py b/ml/ppo/agent.py
index d72f5cb..f9d8ed3 100644
--- a/ml/ppo/agent.py
+++ b/ml/ppo/agent.py
@@ -61,7 +61,7 @@ class Agent:
             state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
 
             values = vals_arr
-            self.advantage = np.zeros(len(reward_arr), dtype=np.float64)
+            advantage = np.zeros(len(reward_arr), dtype=np.float64)
 
             for t in range(len(reward_arr)-1):
                 discount = 1
@@ -71,8 +71,9 @@ class Agent:
                         (reward_arr[k] + self.gamma*values[k+1] *
                          (1-int(dones_arr[k])) - values[k])
                     discount *= self.gamma * self.gae_lambda
-                self.advantage[t] = a_t
-            self.advantage = T.tensor(self.advantage).to(self.actor.device)
+                advantage[t] = a_t
+            self.gae = np.sum(advantage)
+            advantage = T.tensor(advantage).to(self.actor.device)
 
             values = T.tensor(values).to(self.actor.device)
 
@@ -90,15 +91,15 @@ class Agent:
                 new_probs = dist.log_prob(actions)
                 prob_ratio = new_probs.exp() / old_probs.exp()
-                weighted_probs = self.advantage[batch] * prob_ratio
+                weighted_probs = advantage[batch] * prob_ratio
                 weighted_clipped_probs = T.clamp(
-                    prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*self.advantage[batch]
+                    prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
 
                 self.actor_loss = -T.min(weighted_probs,
                                          weighted_clipped_probs).mean()
 
-                returns = self.advantage[batch] + values[batch]
+                returns = advantage[batch] + values[batch]
                 self.critic_loss = (returns - critic_value)**2
                 self.critic_loss = self.critic_loss.mean()
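
Note: for reference, below is a minimal, self-contained sketch of the GAE computation this patch refactors. It is not repository code: the function name compute_gae, its plain-NumPy arguments, and the default gamma/gae_lambda values are illustrative assumptions. The nested-loop recurrence mirrors Agent.learn, and the returned scalar corresponds to the new self.gae = np.sum(advantage) that main.py now records in episode_advantage instead of the per-step tensor.

    import numpy as np

    # Sketch only (not part of the repo); names and defaults are placeholders.
    def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
        """Return the per-step GAE array and its sum (the scalar exposed as self.gae)."""
        advantage = np.zeros(len(rewards), dtype=np.float64)
        for t in range(len(rewards) - 1):
            discount = 1.0
            a_t = 0.0
            for k in range(t, len(rewards) - 1):
                # Discounted one-step TD error, masked when the episode ends at step k.
                a_t += discount * (rewards[k] + gamma * values[k + 1]
                                   * (1 - int(dones[k])) - values[k])
                discount *= gamma * gae_lambda
            advantage[t] = a_t
        return advantage, float(np.sum(advantage))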