# pneuma-pygame/agents/ppo/brain.py
# (Repository web-view metadata: 120 lines, 3.3 KiB, Python.)

import os
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
2023-11-14 21:44:43 +00:00
class PPOMemory:
    """Rollout buffer that stores transitions and serves shuffled mini-batches.

    Transitions are appended one at a time with ``store_memory`` and read
    back as NumPy arrays, together with shuffled index arrays describing the
    mini-batches, by ``generate_batches``.
    """

    def __init__(self, batch_size):
        """Create an empty buffer.

        Args:
            batch_size: Maximum number of transitions per mini-batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. A zero or
                negative step would otherwise only fail (or silently yield
                no batches) later, inside ``generate_batches``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.batch_size = batch_size
        # Single source of truth for the "empty buffer" state.
        self.clear_memory()

    def generate_batches(self):
        """Return the stored data as arrays plus shuffled mini-batch indices.

        Returns:
            A 7-tuple ``(states, actions, probs, vals, rewards, dones,
            batches)``. The first six entries are ``np.ndarray`` views of the
            stored lists, in insertion order. ``batches`` is a list of int64
            index arrays that together cover every stored index exactly
            once; each has ``batch_size`` entries except possibly the last,
            which holds the remainder. An empty buffer yields an empty list.
        """
        n_states = len(self.states)
        batch_start = np.arange(0, n_states, self.batch_size)

        # Shuffle the indices, not the data, so the returned arrays keep
        # their original ordering and callers index them with each batch.
        indices = np.arange(n_states, dtype=np.int64)
        np.random.shuffle(indices)
        batches = [indices[i:i + self.batch_size] for i in batch_start]

        return (
            np.array(self.states),
            np.array(self.actions),
            np.array(self.probs),
            np.array(self.vals),
            np.array(self.rewards),
            np.array(self.dones),
            batches,
        )

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition to the buffer.

        Args:
            state: Observation for this step.
            action: Action taken.
            probs: Value stored alongside the action (presumably its
                (log-)probability under the policy -- confirm against caller).
            vals: Value stored for the state (presumably the critic's
                estimate -- confirm against caller).
            reward: Reward received.
            done: Episode-termination flag.
        """
        self.states.append(state)
        self.actions.append(action)
        self.probs.append(probs)
        self.vals.append(vals)
        self.rewards.append(reward)
        self.dones.append(done)

    def clear_memory(self):
        """Drop all stored transitions, rebinding every field to a new list."""
        self.states = []
        self.probs = []
        self.vals = []
        self.actions = []
        self.rewards = []
        self.dones = []
2023-11-14 21:44:43 +00:00
class ActorNetwork(nn.Module):
    """Policy network: a two-hidden-layer MLP producing a categorical policy.

    The module owns its Adam optimizer and moves itself to the selected
    device (``cuda:0`` if available, else ``mps`` if available, else ``cpu``).
    """

    def __init__(self, input_dim, output_dim, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
        """Build the network, its optimizer, and place it on the device.

        Args:
            input_dim: Size of the last dimension of the input state.
            output_dim: Number of discrete actions (softmax width).
            alpha: Learning rate for the Adam optimizer.
            fc1_dims: Width of the first hidden layer.
            fc2_dims: Width of the second hidden layer.
            chkpt_dir: Directory in which the checkpoint file is stored.
        """
        super().__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')

        self.actor = nn.Sequential(
            nn.Linear(input_dim, fc1_dims),
            nn.ReLU(),
            nn.Linear(fc1_dims, fc2_dims),
            nn.ReLU(),
            nn.Linear(fc2_dims, output_dim),
            nn.Softmax(dim=-1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)

        self.device = T.device('cuda:0' if T.cuda.is_available() else (
            'mps' if T.backends.mps.is_available() else 'cpu'))
        self.to(self.device)

    def forward(self, state):
        """Return the action distribution for ``state``.

        Args:
            state: Tensor whose last dimension is ``input_dim``. It is not
                moved here, so the caller must supply it on ``self.device``.

        Returns:
            A ``Categorical`` distribution built from the softmax output.
        """
        dist = self.actor(state)
        dist = Categorical(dist)

        return dist

    def save_checkpoint(self):
        """Write the module's ``state_dict`` to ``self.checkpoint_file``."""
        # T.save does not create missing parent directories; the default
        # 'tmp/ppo' usually does not exist on a fresh checkout.
        checkpoint_dir = os.path.dirname(self.checkpoint_file)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the module's ``state_dict`` from ``self.checkpoint_file``."""
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # restored on a host that only has another (e.g. CPU).
        self.load_state_dict(
            T.load(self.checkpoint_file, map_location=self.device))
2023-11-14 21:44:43 +00:00
class CriticNetwork(nn.Module):
    """Value network: a two-hidden-layer MLP producing one scalar per state.

    The module owns its Adam optimizer and moves itself to the selected
    device (``cuda:0`` if available, else ``mps`` if available, else ``cpu``).
    """

    def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
        """Build the network, its optimizer, and place it on the device.

        Args:
            input_dims: Size of the last dimension of the input state.
            alpha: Learning rate for the Adam optimizer.
            fc1_dims: Width of the first hidden layer.
            fc2_dims: Width of the second hidden layer.
            chkpt_dir: Directory in which the checkpoint file is stored.
        """
        super().__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')

        self.critic = nn.Sequential(
            nn.Linear(input_dims, fc1_dims),
            nn.ReLU(),
            nn.Linear(fc1_dims, fc2_dims),
            nn.ReLU(),
            nn.Linear(fc2_dims, 1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, eps=1e-5)

        self.device = T.device('cuda:0' if T.cuda.is_available() else (
            'mps' if T.backends.mps.is_available() else 'cpu'))
        self.to(self.device)

    def forward(self, state):
        """Return the value estimate for ``state``.

        Args:
            state: Tensor whose last dimension is ``input_dims``. It is not
                moved here, so the caller must supply it on ``self.device``.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size 1.
        """
        value = self.critic(state)
        return value

    def save_checkpoint(self):
        """Write the module's ``state_dict`` to ``self.checkpoint_file``."""
        # T.save does not create missing parent directories; the default
        # 'tmp/ppo' usually does not exist on a fresh checkout.
        checkpoint_dir = os.path.dirname(self.checkpoint_file)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the module's ``state_dict`` from ``self.checkpoint_file``."""
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # restored on a host that only has another (e.g. CPU).
        self.load_state_dict(
            T.load(self.checkpoint_file, map_location=self.device))