2023-11-14 21:44:43 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch as T
|
2023-07-11 00:25:44 +00:00
|
|
|
|
2023-12-10 19:15:40 +00:00
|
|
|
from tqdm import tqdm
|
|
|
|
|
2023-11-14 21:44:43 +00:00
|
|
|
from .brain import ActorNetwork, CriticNetwork, PPOMemory
|
2023-11-13 18:09:41 +00:00
|
|
|
|
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
class Agent:
    """PPO agent with separate actor and critic networks.

    Transitions are collected with ``remember`` and consumed by ``learn``,
    which runs ``n_epochs`` passes of clipped-surrogate PPO updates over the
    stored rollout and then clears the memory.

    The actor is called on a state tensor and must return an object exposing
    ``sample``, ``log_prob`` and ``entropy`` (presumably a
    ``torch.distributions`` distribution, e.g. ``Categorical`` -- confirm in
    ``brain``); the critic returns a value estimate per state.
    """

    def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
                 policy_clip=0.2, batch_size=64, n_epochs=10,
                 gae_lambda=0.95, entropy_coef=0.001, chkpt_dir='tmp/ppo',
                 max_grad_norm=2):
        """Create the networks and the rollout memory.

        Args:
            input_dims: forwarded to ``ActorNetwork`` and ``CriticNetwork``.
            n_actions: forwarded to ``ActorNetwork``.
            gamma: discount factor used for the TD errors and GAE.
            alpha: forwarded to both networks (presumably their learning
                rate -- confirm in ``brain``).
            policy_clip: PPO clipping epsilon for the probability ratio.
            batch_size: forwarded to ``PPOMemory``.
            n_epochs: optimisation passes per ``learn`` call.
            gae_lambda: GAE lambda.
            entropy_coef: weight of the entropy bonus in the total loss.
            chkpt_dir: checkpoint directory forwarded to both networks.
            max_grad_norm: gradient-norm clip threshold, applied separately
                to the actor and critic parameters (default keeps the
                previous hard-coded value of 2).
        """
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

        self.actor = ActorNetwork(
            input_dims, n_actions, alpha, chkpt_dir=chkpt_dir)
        self.critic = CriticNetwork(
            input_dims, alpha, chkpt_dir=chkpt_dir)
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition in the rollout memory."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
        """Save actor and critic checkpoints under the given names."""
        self.actor.save_checkpoint(actr_chkpt)
        self.critic.save_checkpoint(crtc_chkpt)

    def load_models(self, actr_chkpt='actor_ppo', crtc_chkpt='critic_ppo'):
        """Load actor and critic checkpoints with the given names."""
        self.actor.load_checkpoint(actr_chkpt)
        self.critic.load_checkpoint(crtc_chkpt)

    def choose_action(self, observation):
        """Sample an action for a single observation.

        Returns:
            ``(action, log_prob, value)`` as Python scalars, suitable for
            passing straight to ``remember``.

        Side effect: stores the mean entropy of the action distribution in
        ``self.entropy`` as a float (monitoring only; ``learn`` computes its
        own entropy term).
        """
        state = T.tensor(observation, dtype=T.float).to(self.actor.device)

        # Pure inference: no autograd graph is needed here.
        with T.no_grad():
            dist = self.actor(state)
            value = self.critic(state)
            action = dist.sample()

            probs = T.squeeze(dist.log_prob(action)).item()
            action = T.squeeze(action).item()
            value = T.squeeze(value).item()
            self.entropy = dist.entropy().mean().item()

        return action, probs, value

    @staticmethod
    def _compute_advantages(rewards, values, dones, gamma, gae_lambda):
        """Return GAE advantages for one stored rollout (float64 ndarray).

            delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}

        The ``(1 - done_t)`` factor in the recursion stops advantages from
        accumulating rewards of the *next* episode across a terminal step.
        The last stored step has no successor value in the buffer, so its
        advantage stays 0. Runs in O(n) by iterating backwards.
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.asarray(values, dtype=np.float64)
        not_done = 1.0 - np.asarray(dones, dtype=np.float64)

        n = len(rewards)
        advantage = np.zeros(n, dtype=np.float64)
        if n < 2:
            return advantage

        deltas = rewards[:-1] + gamma * values[1:] * not_done[:-1] - values[:-1]
        running = 0.0
        for t in range(n - 2, -1, -1):
            running = deltas[t] + gamma * gae_lambda * not_done[t] * running
            advantage[t] = running
        return advantage

    def learn(self):
        """Run ``n_epochs`` of PPO updates on the stored rollout, then clear it.

        Sets ``self.actor_loss``, ``self.critic_loss`` and ``self.total_loss``
        to the losses of the last minibatch processed.
        """
        device = self.actor.device

        for _ in tqdm(range(self.n_epochs),
                      desc='Learning...',
                      dynamic_ncols=True,
                      leave=False,
                      ascii=True):
            (state_arr, action_arr, old_probs_arr, vals_arr,
             reward_arr, dones_arr, batches) = self.memory.generate_batches()

            advantage = T.tensor(
                self._compute_advantages(reward_arr, vals_arr, dones_arr,
                                         self.gamma, self.gae_lambda),
                dtype=T.float).to(device)
            values = T.tensor(np.asarray(vals_arr), dtype=T.float).to(device)

            for batch in batches:
                states = T.tensor(state_arr[batch], dtype=T.float).to(device)
                old_probs = T.tensor(
                    old_probs_arr[batch], dtype=T.float).to(device)
                actions = T.tensor(action_arr[batch]).to(device)

                dist = self.actor(states)
                critic_value = T.squeeze(self.critic(states))

                new_probs = dist.log_prob(actions)
                # exp(new - old) == exp(new) / exp(old), without the risk of
                # overflow/underflow in the individual exponentials.
                prob_ratio = (new_probs - old_probs).exp()

                batch_adv = advantage[batch]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = T.clamp(
                    prob_ratio, 1 - self.policy_clip,
                    1 + self.policy_clip) * batch_adv
                self.actor_loss = -T.min(weighted_probs,
                                         weighted_clipped_probs).mean()

                returns = batch_adv + values[batch]
                self.critic_loss = ((returns - critic_value) ** 2).mean()

                # Entropy of the current policy on this minibatch, kept in
                # the graph so the bonus actually produces gradients.
                entropy = dist.entropy().mean()
                self.total_loss = (self.actor_loss
                                   + 0.5 * self.critic_loss
                                   - self.entropy_coef * entropy)

                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                self.total_loss.backward()

                T.nn.utils.clip_grad_norm_(
                    self.actor.parameters(), max_norm=self.max_grad_norm)
                T.nn.utils.clip_grad_norm_(
                    self.critic.parameters(), max_norm=self.max_grad_norm)

                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()
|