lyceum-env/students/sac/brain.py

140 lines
4.7 KiB
Python
Raw Normal View History

2024-10-03 08:15:39 +00:00
import os

import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
class CriticNetwork(nn.Module):
    """State-action value network: maps a (state, action) pair to one scalar.

    The state and action are concatenated along dim 1, passed through an
    input layer and a hidden layer (tanh after each), and projected to a
    single output.

    Args:
        beta: Learning rate handed to the Adam optimizer.
        input_dims: Indexable observation shape, e.g. ``(obs_dim,)``; only the
            first entry is used.
        n_actions: Number of action components appended to the state.
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint is written to / read from.
    """

    def __init__(self, beta, input_dims, n_actions, name='critic', chkpt_dir='tmp/sac'):
        super().__init__()
        self.input_dims = input_dims
        # input_dims is indexed below, so it is a shape container; the hidden
        # width has to be derived from its first entry, not from the container
        # (``(8,) // 2`` raises TypeError). Never let the width collapse to 0.
        state_features = int(input_dims[0])
        self.fc_size = max(1, state_features // 2)
        self.n_actions = n_actions
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(state_features + n_actions, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state, action):
        """Return the value estimate for each (state, action) row.

        Both arguments are concatenated along dim 1, so they must be 2-D with
        the same number of rows. The result has one column per row.
        """
        hidden = T.tanh(self.input_layer(T.cat([state, action], dim=1)))
        # The same middle_layer (shared weights) is applied three times, as in
        # the original architecture; this keeps the parameter set and existing
        # checkpoints unchanged.
        # NOTE(review): confirm the weight sharing is intentional.
        for _ in range(3):
            hidden = T.tanh(self.middle_layer(hidden))
        return self.output_layer(hidden)

    def save(self):
        """Write the network parameters to ``self.chkpt_file``."""
        if self.chkpt_dir:
            # T.save does not create missing parent directories.
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the network parameters from ``self.chkpt_file``."""
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))
class ValueNetwork(nn.Module):
    """State value network: maps a state to one scalar.

    Two tanh-activated linear layers followed by a linear projection to a
    single output.

    Args:
        beta: Learning rate handed to the Adam optimizer.
        input_dims: Indexable observation shape, e.g. ``(obs_dim,)``; only the
            first entry is used.
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint is written to / read from.
    """

    def __init__(self, beta, input_dims, name='value', chkpt_dir='tmp/sac'):
        super().__init__()
        self.input_dims = input_dims
        # input_dims is a shape container; derive the hidden width from its
        # first entry (floor-dividing the container itself raises TypeError)
        # and never let the width collapse to 0.
        state_features = int(input_dims[0])
        self.fc_size = max(1, state_features // 2)
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(state_features, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return the value estimate for each row of ``state``."""
        hidden = T.tanh(self.input_layer(state))
        # The hidden activation, not the raw state, feeds the middle layer:
        # the middle layer expects fc_size features, not the state width.
        hidden = T.tanh(self.middle_layer(hidden))
        return self.output_layer(hidden)

    def save(self):
        """Write the network parameters to ``self.chkpt_file``."""
        if self.chkpt_dir:
            # T.save does not create missing parent directories.
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the network parameters from ``self.chkpt_file``."""
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))
class ActorNetwork(nn.Module):
    """Gaussian policy network with tanh-squashed, scaled actions.

    ``forward`` produces the mean and standard deviation of a diagonal Normal
    distribution; ``sample_normal`` draws from it, squashes the draw with tanh,
    scales it by ``max_action`` and returns the matching log-probability.

    Args:
        alpha: Learning rate handed to the Adam optimizer.
        input_dims: Indexable observation shape, e.g. ``(obs_dim,)``; only the
            first entry is used.
        n_actions: Number of action components.
        max_action: Scale applied to the tanh-squashed sample (scalar or
            per-component sequence/array/tensor).
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint is written to / read from.
    """

    def __init__(self, alpha, input_dims, n_actions, max_action, name='actor', chkpt_dir='tmp/sac'):
        super().__init__()
        self.input_dims = input_dims
        # input_dims is a shape container; derive the hidden width from its
        # first entry (floor-dividing the container itself raises TypeError)
        # and never let the width collapse to 0.
        state_features = int(input_dims[0])
        self.fc_size = max(1, state_features // 4)
        self.n_actions = n_actions
        self.max_action = max_action
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')
        # Lower bound for sigma and stabilizer inside the log below.
        self.reparam_noise = 1e-6

        self.input_layer = nn.Linear(state_features, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.mu = nn.Linear(self.fc_size, self.n_actions)
        self.sigma = nn.Linear(self.fc_size, self.n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return ``(mu, sigma)`` of the action distribution for ``state``.

        ``sigma`` is clamped to ``[reparam_noise, 1]`` so it is always a valid
        (strictly positive) standard deviation.
        """
        hidden = T.tanh(self.input_layer(state))
        hidden = T.tanh(self.middle_layer(hidden))
        # The shared middle_layer is applied a second time without an
        # activation, as in the original architecture.
        # NOTE(review): confirm the reuse / missing activation is intentional.
        hidden = self.middle_layer(hidden)
        mu = self.mu(hidden)
        sigma = T.clamp(self.sigma(hidden), min=self.reparam_noise, max=1)
        return mu, sigma

    def sample_normal(self, state, reparametrize=True):
        """Sample an action and its log-probability.

        Args:
            state: Batch of states (2-D; rows are summed over dim 1 below).
            reparametrize: If true, use ``rsample`` so gradients flow through
                the sample; otherwise use ``sample``.

        Returns:
            ``(action, log_probs)`` where ``action`` is
            ``tanh(sample) * max_action`` and ``log_probs`` has one column per
            row of ``state``.
        """
        mu, sigma = self.forward(state)
        distribution = Normal(mu, sigma)
        raw = distribution.rsample() if reparametrize else distribution.sample()

        squashed = T.tanh(raw)
        # as_tensor avoids a copy/warning when max_action is already a tensor
        # and keeps the action in the network's dtype.
        scale = T.as_tensor(self.max_action, dtype=squashed.dtype, device=self.device)
        action = squashed * scale

        # Change of variables for the tanh squash: SUBTRACT log(1 - tanh(u)^2)
        # from the Gaussian log-density (the original overwrote it instead).
        # The unscaled tanh value is used so the argument of the log stays
        # positive even when max_action > 1; the constant log|max_action|
        # term of the scaling is omitted.
        log_probs = distribution.log_prob(raw)
        log_probs = log_probs - T.log(1 - squashed.pow(2) + self.reparam_noise)
        log_probs = log_probs.sum(1, keepdim=True)
        return action, log_probs

    def save(self):
        """Write the network parameters to ``self.chkpt_file``."""
        if self.chkpt_dir:
            # T.save does not create missing parent directories.
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the network parameters from ``self.chkpt_file``."""
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))