pneuma-pygame/agents/ppo/agent.py

122 lines
4.3 KiB
Python
Raw Normal View History

2023-11-14 21:44:43 +00:00
import numpy as np
import torch as T
2023-12-10 19:15:40 +00:00
from tqdm import tqdm
2023-11-14 21:44:43 +00:00
from .brain import ActorNetwork, CriticNetwork, PPOMemory
2023-11-13 18:09:41 +00:00
class Agent:
    """PPO agent with separate actor and critic networks.

    Transitions are collected with ``remember`` and consumed by ``learn``, which
    runs ``n_epochs`` passes of clipped-surrogate updates over the stored
    rollout and then clears the memory.
    """

    def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
                 policy_clip=0.2, batch_size=64, n_epochs=10,
                 gae_lambda=0.95, entropy_coef=0.001, chkpt_dir='tmp/ppo',
                 value_coef=0.5, max_grad_norm=2.0):
        """Build the actor, critic and rollout memory.

        Args:
            input_dims: observation dimensions, forwarded to both networks.
            n_actions: action-space size, forwarded to the actor.
            gamma: discount factor used in advantage estimation.
            alpha: forwarded to both networks (presumably the optimizer
                learning rate -- confirm in brain.py).
            policy_clip: PPO ratio clipping epsilon.
            batch_size: minibatch size, forwarded to PPOMemory.
            n_epochs: passes over the stored rollout per ``learn`` call.
            gae_lambda: GAE smoothing parameter.
            entropy_coef: weight of the entropy bonus in the total loss.
            chkpt_dir: checkpoint directory forwarded to both networks.
            value_coef: weight of the critic loss in the total loss.
            max_grad_norm: gradient-norm clip, applied separately to the
                actor's and the critic's parameters.
        """
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm

        self.actor = ActorNetwork(
            input_dims, n_actions, alpha, chkpt_dir=chkpt_dir)
        self.critic = CriticNetwork(
            input_dims, alpha, chkpt_dir=chkpt_dir)
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition in the rollout memory."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
        """Save actor and critic checkpoints under the given names."""
        self.actor.save_checkpoint(actr_chkpt)
        self.critic.save_checkpoint(crtc_chkpt)

    def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
        """Load actor and critic checkpoints saved under the given names."""
        self.actor.load_checkpoint(actr_chkpt)
        self.critic.load_checkpoint(crtc_chkpt)

    def choose_action(self, observation):
        """Sample an action from the current policy for one observation.

        Args:
            observation: converted to a float tensor. It must yield a
                single-element action/log-prob/value, because each is
                reduced with ``.item()``.

        Returns:
            ``(action, log_prob, value)`` as Python scalars. As a side effect,
            the policy's mean entropy is stored on ``self.entropy`` as a float.
        """
        state = T.tensor(observation, dtype=T.float).to(self.actor.device)
        # Everything below is reduced to Python scalars, so skip building an
        # autograd graph for this inference-only forward pass.
        with T.no_grad():
            dist = self.actor(state)
            value = self.critic(state)
            action = dist.sample()
            probs = T.squeeze(dist.log_prob(action)).item()
            self.entropy = dist.entropy().mean().item()

        action = T.squeeze(action).item()
        value = T.squeeze(value).item()
        return action, probs, value

    @staticmethod
    def _compute_gae(rewards, values, dones, gamma, gae_lambda):
        """Return Generalized Advantage Estimates for a stored rollout.

        Uses the backward recursion
        ``A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}`` with
        ``delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t``.
        The ``done`` mask stops advantages leaking across episode boundaries.
        The final step has no successor value, so its advantage stays 0.

        Args:
            rewards, values, dones: equal-length sequences for the rollout.
            gamma: discount factor.
            gae_lambda: GAE smoothing parameter.

        Returns:
            A float64 numpy array of advantages, the same length as ``rewards``.
        """
        n = len(rewards)
        advantage = np.zeros(n, dtype=np.float64)
        running = 0.0
        for t in reversed(range(n - 1)):
            not_done = 1.0 - float(dones[t])
            delta = (rewards[t] + gamma * values[t + 1] * not_done
                     - values[t])
            running = delta + gamma * gae_lambda * not_done * running
            advantage[t] = running
        return advantage

    def learn(self):
        """Run ``n_epochs`` of PPO updates on the stored rollout, then clear it.

        For every minibatch, this minimizes
        ``actor_loss + value_coef * critic_loss - entropy_coef * entropy``.
        Each network's gradients are clipped to ``max_grad_norm`` before its
        optimizer steps. The most recent losses remain available on
        ``self.actor_loss``, ``self.critic_loss`` and ``self.total_loss``.
        """
        device = self.actor.device
        for _ in tqdm(range(self.n_epochs),
                      desc='Learning...',
                      dynamic_ncols=True,
                      leave=False,
                      ascii=True):
            (state_arr, action_arr, old_probs_arr, vals_arr,
             reward_arr, dones_arr, batches) = self.memory.generate_batches()

            # O(n) per epoch, so recomputing per epoch is cheap and does not
            # depend on generate_batches() returning identical arrays.
            advantage = T.tensor(
                self._compute_gae(reward_arr, vals_arr, dones_arr,
                                  self.gamma, self.gae_lambda),
                dtype=T.float).to(device)
            values = T.tensor(vals_arr, dtype=T.float).to(device)

            for batch in batches:
                states = T.tensor(state_arr[batch], dtype=T.float).to(device)
                old_probs = T.tensor(
                    old_probs_arr[batch], dtype=T.float).to(device)
                actions = T.tensor(action_arr[batch]).to(device)

                dist = self.actor(states)
                critic_value = T.squeeze(self.critic(states))

                new_probs = dist.log_prob(actions)
                # exp(new - old) equals p_new / p_old, but avoids dividing
                # two possibly tiny probabilities.
                prob_ratio = (new_probs - old_probs).exp()
                batch_adv = advantage[batch]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = T.clamp(
                    prob_ratio, 1 - self.policy_clip,
                    1 + self.policy_clip) * batch_adv
                self.actor_loss = -T.min(weighted_probs,
                                         weighted_clipped_probs).mean()

                returns = batch_adv + values[batch]
                self.critic_loss = ((returns - critic_value) ** 2).mean()

                # The entropy bonus must come from the current distribution so
                # that it carries gradient. The float cached by
                # choose_action() would be a constant here.
                entropy = dist.entropy().mean()

                self.total_loss = (self.actor_loss
                                   + self.value_coef * self.critic_loss
                                   - self.entropy_coef * entropy)
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                self.total_loss.backward()
                T.nn.utils.clip_grad_norm_(
                    self.actor.parameters(), max_norm=self.max_grad_norm)
                T.nn.utils.clip_grad_norm_(
                    self.critic.parameters(), max_norm=self.max_grad_norm)
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()