pneuma-pygame/agent/agent.py

104 lines
3.8 KiB
Python
Raw Normal View History

2023-11-14 21:44:43 +00:00
import numpy as np
import torch as T
2023-11-14 21:44:43 +00:00
from .brain import ActorNetwork, CriticNetwork, PPOMemory
2023-11-13 18:09:41 +00:00
class Agent:
    """PPO agent pairing an actor (policy) network with a critic (value) network.

    Transitions are gathered with ``choose_action`` / ``remember`` into a
    ``PPOMemory`` buffer; ``learn`` then runs ``n_epochs`` passes of the
    clipped-surrogate PPO update over that buffer and clears it.
    """

    def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
                 policy_clip=0.2, batch_size=64, N=2048, n_epochs=10,
                 gae_lambda=0.95):
        """Build the actor, the critic and the rollout memory.

        Args:
            input_dims: forwarded to both ActorNetwork and CriticNetwork.
            n_actions: forwarded to ActorNetwork (size of the action space).
            gamma: reward discount factor used when computing advantages.
            alpha: forwarded to both networks (presumably their learning
                rate -- confirm in brain.py).
            policy_clip: probability ratios are clamped to
                ``[1 - policy_clip, 1 + policy_clip]`` in the actor loss.
            batch_size: forwarded to PPOMemory.
            N: not used by this class; kept so existing callers keep working
                (presumably the caller's update interval -- confirm).
            n_epochs: number of passes over the stored rollout per ``learn``.
            gae_lambda: lambda of the generalized advantage estimate.
        """
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda

        print("Preparing Actor model...")
        self.actor = ActorNetwork(input_dims, n_actions, alpha)
        print(f"Actor network activated using {self.actor.device}")
        print("\nPreparing Critic model...")
        self.critic = CriticNetwork(input_dims, alpha)
        print(f"Critic network activated using {self.critic.device}")
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Append one transition to the rollout memory.

        Arguments are forwarded unchanged to ``PPOMemory.store_memory``.
        """
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self):
        """Write actor and critic checkpoints via their ``save_checkpoint``."""
        print('... saving models ...')
        self.actor.save_checkpoint()
        self.critic.save_checkpoint()
        print('... done ...')

    def load_models(self):
        """Restore actor and critic via their ``load_checkpoint``."""
        print('... loading models ...')
        self.actor.load_checkpoint()
        self.critic.load_checkpoint()
        print('.. done ...')

    def choose_action(self, observation):
        """Sample an action for ``observation`` from the current policy.

        Args:
            observation: a torch tensor (``.to`` is called on it); assumed to
                already carry a leading batch dimension of 1 -- TODO confirm.

        Returns:
            ``(action, log_prob, value)`` as Python scalars: the sampled
            action, its log-probability under the policy distribution, and
            the critic's value estimate for the state.
        """
        state = observation.to(self.actor.device, dtype=T.float)
        # Rollout only needs plain numbers (.item() below), so skip building
        # an autograd graph for these forward passes.
        with T.no_grad():
            dist = self.actor(state)
            value = self.critic(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)

        probs = T.squeeze(log_prob).item()
        action = T.squeeze(action).item()
        value = T.squeeze(value).item()

        return action, probs, value

    def _compute_advantages(self, rewards, values, dones):
        """Return GAE advantages for one stored rollout as a float32 array.

        Backward recursion over the rollout:
            delta_t = r_t + gamma * V_{t+1} * (1 - d_t) - V_t
            A_t     = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
        The ``(1 - d_t)`` factor on ``A_{t+1}`` stops the sum at episode
        boundaries. The final step has no successor value to bootstrap from,
        so its advantage is left at 0.
        """
        n = len(rewards)
        advantage = np.zeros(n, dtype=np.float32)
        running = 0.0
        for t in reversed(range(n - 1)):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + self.gamma * values[t + 1] * not_done - values[t]
            running = delta + self.gamma * self.gae_lambda * not_done * running
            advantage[t] = running
        return advantage

    def learn(self):
        """Run ``n_epochs`` of PPO updates over the stored rollout, then clear it.

        Each epoch draws mini-batches from memory, computes GAE advantages,
        and takes one optimizer step per mini-batch on
        ``actor_loss + 0.5 * critic_loss``. All tensors are placed on the
        actor's device (assumes the critic lives on the same device).
        """
        device = self.actor.device
        for _ in range(self.n_epochs):
            (state_arr, action_arr, old_probs_arr, vals_arr,
             reward_arr, done_arr, batches) = self.memory.generate_batches()

            advantage = T.tensor(
                self._compute_advantages(reward_arr, vals_arr, done_arr)
            ).to(device)
            # float32 to match the critic output instead of numpy's float64.
            values = T.tensor(vals_arr, dtype=T.float).to(device)

            for batch in batches:
                states = T.tensor(state_arr[batch], dtype=T.float).to(device)
                old_probs = T.tensor(old_probs_arr[batch], dtype=T.float).to(device)
                actions = T.tensor(action_arr[batch]).to(device)

                dist = self.actor(states)
                critic_value = T.squeeze(self.critic(states))

                new_probs = dist.log_prob(actions)
                # exp(new - old) == exp(new) / exp(old), but cannot turn into
                # 0/0 when very negative log-probs underflow in exp().
                prob_ratio = (new_probs - old_probs).exp()
                batch_adv = advantage[batch]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = T.clamp(
                    prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip
                ) * batch_adv
                actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

                # Critic regresses onto the GAE return target A_t + V_t.
                returns = batch_adv + values[batch]
                critic_loss = ((returns - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss

                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()