2023-07-11 00:25:44 +00:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import torch as T
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.optim as optim
|
|
|
|
from torch.distributions.categorical import Categorical
|
|
|
|
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
class PPOMemory:
    """Rollout buffer that collects transitions and serves shuffled minibatches.

    Transitions are appended one at a time with ``store_memory``.
    ``generate_batches`` returns everything stored as NumPy arrays plus a list
    of shuffled index minibatches. ``clear_memory`` empties the buffer.
    """

    def __init__(self, batch_size):
        """Create an empty buffer.

        Args:
            batch_size: Maximum number of indices per minibatch returned by
                ``generate_batches`` (the final minibatch may be shorter).
        """
        self.clear_memory()
        self.batch_size = batch_size

    def generate_batches(self):
        """Return all stored data and a random partition of its indices.

        Returns:
            A 7-tuple ``(states, actions, probs, vals, rewards, dones, batches)``.
            The first six are ``np.array`` conversions of the stored lists, in
            insertion order. ``batches`` is a list of int64 index arrays, each
            at most ``batch_size`` long, which together cover
            ``range(len(states))`` exactly once in an order shuffled by
            NumPy's global RNG (``np.random.shuffle``).
        """
        size = len(self.states)
        # Minibatch start offsets are computed before shuffling.
        starts = np.arange(0, size, self.batch_size)
        order = np.arange(size, dtype=np.int64)
        np.random.shuffle(order)

        minibatches = []
        for start in starts:
            minibatches.append(order[start:start + self.batch_size])

        return (
            np.array(self.states),
            np.array(self.actions),
            np.array(self.probs),
            np.array(self.vals),
            np.array(self.rewards),
            np.array(self.dones),
            minibatches,
        )

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition; each argument goes onto its own list."""
        for buffer, item in (
            (self.states, state),
            (self.actions, action),
            (self.probs, probs),
            (self.vals, vals),
            (self.rewards, reward),
            (self.dones, done),
        ):
            buffer.append(item)

    def clear_memory(self):
        """Discard all stored transitions by rebinding every list to a new empty one."""
        (
            self.states,
            self.probs,
            self.vals,
            self.actions,
            self.rewards,
            self.dones,
        ) = [], [], [], [], [], []
|
|
|
|
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
class ActorNetwork(nn.Module):
    """Policy MLP that outputs a categorical distribution over discrete actions.

    Architecture: three LeakyReLU hidden layers followed by a linear layer
    and a softmax. ``forward`` wraps the softmax output in a ``Categorical``
    distribution. The network owns an Adam optimizer and is moved to CUDA
    when it is available, otherwise to CPU.
    """

    def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp'):
        """Build the policy network and its optimizer.

        Args:
            input_dim: Size of the last dimension of the state tensor.
            output_dim: Number of discrete actions (width of the softmax output).
            alpha: Learning rate for the Adam optimizer.
            fc1_dims: Width of the first hidden layer.
            fc2_dims: Width of the second and third hidden layers.
            chkpt_dir: Stored as ``self.chkpt_dir``. It is not used by
                ``save_checkpoint``/``load_checkpoint``, which take an
                explicit filename.
        """
        super(ActorNetwork, self).__init__()

        self.chkpt_dir = chkpt_dir

        # The module order (and so the state_dict keys actor.0 ... actor.7) is
        # unchanged. Checkpoints saved with fc1_dims == fc2_dims therefore
        # still load; that was the only configuration that used to work.
        self.actor = nn.Sequential(
            nn.Linear(input_dim, fc1_dims),
            nn.LeakyReLU(),
            nn.Linear(fc1_dims, fc2_dims),
            nn.LeakyReLU(),
            # This layer's input is the previous layer's output (fc2_dims).
            # It was declared with fc1_dims, which crashed forward() whenever
            # fc1_dims != fc2_dims.
            nn.Linear(fc2_dims, fc2_dims),
            nn.LeakyReLU(),
            nn.Linear(fc2_dims, output_dim),
            nn.Softmax(dim=-1),
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)

        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return a ``Categorical`` distribution over actions for ``state``.

        Assumes ``state`` is a float tensor whose last dimension is
        ``input_dim`` and that lives on ``self.device``. The caller is
        responsible for both; this method does not convert or move it.
        """
        probs = self.actor(state)
        return Categorical(probs)

    def save_checkpoint(self, filename):
        """Write this network's ``state_dict`` to ``filename`` via ``torch.save``."""
        T.save(self.state_dict(), filename)

    def load_checkpoint(self, filename):
        """Print ``filename``, then load a ``state_dict`` from it onto ``self.device``."""
        print(filename)
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or pass weights_only=True on torch versions
        # that support it.
        self.load_state_dict(T.load(filename, map_location=self.device))
|
2023-07-11 00:25:44 +00:00
|
|
|
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
class CriticNetwork(nn.Module):
|
|
|
|
|
2024-03-11 10:44:20 +00:00
|
|
|
def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp'):
|
2023-07-11 00:25:44 +00:00
|
|
|
super(CriticNetwork, self).__init__()
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-11-23 15:37:02 +00:00
|
|
|
self.chkpt_dir = chkpt_dir
|
2023-11-17 02:19:03 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
self.critic = nn.Sequential(
|
2023-11-14 21:44:43 +00:00
|
|
|
nn.Linear(input_dims, fc1_dims),
|
2023-12-09 12:48:16 +00:00
|
|
|
nn.LeakyReLU(),
|
2023-11-14 21:44:43 +00:00
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
2023-12-09 12:48:16 +00:00
|
|
|
nn.LeakyReLU(),
|
2024-02-10 17:26:54 +00:00
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
|
|
|
nn.LeakyReLU(),
|
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
|
|
|
nn.LeakyReLU(),
|
2024-03-11 10:44:20 +00:00
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
|
|
|
nn.LeakyReLU(),
|
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
|
|
|
nn.LeakyReLU(),
|
2024-02-10 17:26:54 +00:00
|
|
|
nn.Linear(fc1_dims, fc2_dims),
|
|
|
|
nn.LeakyReLU(),
|
2023-11-14 21:44:43 +00:00
|
|
|
nn.Linear(fc2_dims, 1)
|
|
|
|
)
|
|
|
|
|
2024-02-29 17:07:31 +00:00
|
|
|
self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9), eps=1e-5)
|
2023-11-19 03:27:47 +00:00
|
|
|
|
2024-03-14 18:14:16 +00:00
|
|
|
self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
self.to(self.device)
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2023-07-11 00:25:44 +00:00
|
|
|
def forward(self, state):
|
2023-11-14 21:44:43 +00:00
|
|
|
value = self.critic(state)
|
2023-07-11 00:25:44 +00:00
|
|
|
return value
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2024-02-29 17:07:31 +00:00
|
|
|
def save_checkpoint(self, filename):
|
|
|
|
T.save(self.state_dict(), os.path.join(filename))
|
2023-11-14 21:44:43 +00:00
|
|
|
|
2024-02-29 17:07:31 +00:00
|
|
|
def load_checkpoint(self, filename):
|
|
|
|
print(filename)
|
2023-11-29 10:53:30 +00:00
|
|
|
self.load_state_dict(
|
2024-02-29 17:07:31 +00:00
|
|
|
T.load(os.path.join(filename),
|
2023-11-29 10:53:30 +00:00
|
|
|
map_location=self.device))
|