Made the agent advantage accessible

Vasilis Valatsos 2024-03-07 10:43:14 +02:00
parent feece23330
commit 30dda47f95


@@ -61,7 +61,7 @@ class Agent:
         state_arr, action_arr, old_probs_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.generate_batches()
         values = vals_arr
-        advantage = np.zeros(len(reward_arr), dtype=np.float64)
+        self.advantage = np.zeros(len(reward_arr), dtype=np.float64)
         for t in range(len(reward_arr)-1):
             discount = 1
@@ -71,8 +71,8 @@ class Agent:
                     (reward_arr[k] + self.gamma*values[k+1]
                      * (1-int(dones_arr[k])) - values[k])
                 discount *= self.gamma * self.gae_lambda
-            advantage[t] = a_t
+            self.advantage[t] = a_t
-        advantage = T.tensor(advantage).to(self.actor.device)
+        self.advantage = T.tensor(self.advantage).to(self.actor.device)
         values = T.tensor(values).to(self.actor.device)
@@ -90,15 +90,15 @@ class Agent:
             new_probs = dist.log_prob(actions)
             prob_ratio = new_probs.exp() / old_probs.exp()
-            weighted_probs = advantage[batch] * prob_ratio
+            weighted_probs = self.advantage[batch] * prob_ratio
             weighted_clipped_probs = T.clamp(
-                prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
+                prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*self.advantage[batch]
             self.actor_loss = -T.min(weighted_probs,
                                      weighted_clipped_probs).mean()
-            returns = advantage[batch] + values[batch]
+            returns = self.advantage[batch] + values[batch]
             self.critic_loss = (returns - critic_value)**2
             self.critic_loss = self.critic_loss.mean()
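For reference, the loop touched in the first two hunks computes generalized advantage estimation: each entry ends up as A_t ≈ Σ_{k=t} (γλ)^{k−t} · (r_k + γ·V_{k+1}·(1 − done_k) − V_k). Keeping that buffer on the instance as self.advantage is what makes the per-timestep estimates readable after the update, instead of living only in a local variable that is discarded when the method returns.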
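A minimal sketch of how the exposed attribute could be consumed; the constructor arguments and the update method name learn() are assumptions for illustration, only the self.advantage attribute itself comes from this commit:

# Hypothetical usage -- Agent arguments and learn() are assumed, not part of this diff.
agent = Agent(n_actions=4, input_dims=(8,), batch_size=5, n_epochs=4)
# ... collect a rollout into agent.memory ...
agent.learn()
# self.advantage persists after the update as a tensor on the actor's device,
# so the latest GAE values can be logged or plotted:
print("mean advantage:", agent.advantage.mean().item())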