139 lines
4.7 KiB
Python
139 lines
4.7 KiB
Python
import os

import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
|
|
|
|
class CriticNetwork(nn.Module):
    """Q-network that maps a (state, action) pair to one scalar value.

    The state and action are concatenated along dim 1, passed through an
    input layer and a single hidden layer that is applied three times
    (the same weights are reused on every pass), each followed by tanh,
    and finally projected to one output unit.

    Args:
        beta: Learning rate for the Adam optimizer.
        input_dims: Observation size, either a bare integer or a
            one-element shape such as ``(obs_dim,)``.
        n_actions: Number of action components appended to the state.
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint file is written to.
    """

    def __init__(self, beta, input_dims, n_actions, name='critic',
                 chkpt_dir='tmp/sac'):
        super().__init__()
        # Accept a plain integer as well as a shape tuple/list.
        if np.ndim(input_dims) == 0:
            input_dims = (int(input_dims),)
        self.input_dims = input_dims
        state_dim = int(self.input_dims[0])
        # Hidden width is half the observation size, never below one unit.
        self.fc_size = max(1, state_dim // 2)
        self.n_actions = n_actions
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(state_dim + n_actions, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    def forward(self, state, action):
        """Return the network output for a batch of state/action pairs.

        Args:
            state: 2-D tensor whose dim 1 is concatenated with ``action``.
            action: 2-D tensor with the same leading dimension as ``state``.

        Returns:
            Tensor with one value per row of the batch.
        """
        action_value = T.tanh(self.input_layer(T.cat([state, action], dim=1)))
        # The hidden layer is deliberately reused: three passes through the
        # same weights, each followed by a tanh non-linearity.
        for _ in range(3):
            action_value = T.tanh(self.middle_layer(action_value))

        return self.output_layer(action_value)

    def save(self):
        """Write the state dict to ``chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Restore the state dict from ``chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))
|
|
|
|
class ValueNetwork(nn.Module):
    """State-value network that maps a state to one scalar value.

    The state passes through an input layer and one hidden layer, each
    followed by tanh, and is then projected to a single output unit.

    Args:
        beta: Learning rate for the Adam optimizer.
        input_dims: Observation size, either a bare integer or a
            one-element shape such as ``(obs_dim,)``.
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint file is written to.
    """

    def __init__(self, beta, input_dims, name='value', chkpt_dir='tmp/sac'):
        super().__init__()
        # Accept a plain integer as well as a shape tuple/list.
        if np.ndim(input_dims) == 0:
            input_dims = (int(input_dims),)
        self.input_dims = input_dims
        state_dim = int(self.input_dims[0])
        # Hidden width is half the observation size, never below one unit.
        self.fc_size = max(1, state_dim // 2)
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(state_dim, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    def forward(self, state):
        """Return the network output for a batch of states.

        Args:
            state: Tensor whose last dimension matches the input layer.

        Returns:
            Tensor with one value per input row.
        """
        state_value = T.tanh(self.input_layer(state))
        # Feed the hidden activations (not the raw state) to the middle layer.
        state_value = T.tanh(self.middle_layer(state_value))

        return self.output_layer(state_value)

    def save(self):
        """Write the state dict to ``chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Restore the state dict from ``chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))
|
|
|
|
class ActorNetwork(nn.Module):
    """Gaussian policy network with tanh-squashed, scaled actions.

    ``forward`` produces the mean and standard deviation of a diagonal
    Normal distribution; ``sample_normal`` draws from it, squashes the draw
    with tanh, scales it by ``max_action`` and returns the matching
    log-probability.

    Args:
        alpha: Learning rate for the Adam optimizer.
        input_dims: Observation size, either a bare integer or a
            one-element shape such as ``(obs_dim,)``.
        n_actions: Number of action components.
        max_action: Scale applied to the tanh-squashed sample (scalar or
            per-component array-like).
        name: Prefix of the checkpoint file name.
        chkpt_dir: Directory the checkpoint file is written to.
    """

    def __init__(self, alpha, input_dims, n_actions, max_action,
                 name='actor', chkpt_dir='tmp/sac'):
        super().__init__()
        # Accept a plain integer as well as a shape tuple/list.
        if np.ndim(input_dims) == 0:
            input_dims = (int(input_dims),)
        self.input_dims = input_dims
        state_dim = int(self.input_dims[0])
        # Hidden width is a quarter of the observation size, at least one unit.
        self.fc_size = max(1, state_dim // 4)
        self.n_actions = n_actions
        self.max_action = max_action
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        # Lower bound for sigma and epsilon inside the log of the tanh term.
        self.reparam_noise = 1e-6

        self.input_layer = nn.Linear(state_dim, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.mu = nn.Linear(self.fc_size, self.n_actions)
        self.sigma = nn.Linear(self.fc_size, self.n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    def forward(self, state):
        """Return ``(mu, sigma)`` of the action distribution for ``state``.

        ``sigma`` is clamped to ``[reparam_noise, 1]`` so it is always a
        valid, strictly positive scale.
        """
        prob = T.tanh(self.input_layer(state))
        prob = T.tanh(self.middle_layer(prob))
        # NOTE(review): the second pass through the shared middle layer has no
        # activation before the heads; kept as in the original - confirm intent.
        prob = self.middle_layer(prob)

        mu = self.mu(prob)
        sigma = T.clamp(self.sigma(prob), min=self.reparam_noise, max=1)

        return mu, sigma

    def sample_normal(self, state, reparametrize=True):
        """Sample a bounded action and its log-probability.

        Args:
            state: 2-D batch of states (the log-probability is summed
                over dim 1).
            reparametrize: Use ``rsample`` (gradients flow through the
                sample) when true, plain ``sample`` otherwise.

        Returns:
            ``(action, log_probs)`` where ``action`` is
            ``tanh(sample) * max_action`` and ``log_probs`` has shape
            ``(batch, 1)``.
        """
        mu, sigma = self.forward(state)
        probabilities = Normal(mu, sigma)

        if reparametrize:
            actions = probabilities.rsample()
        else:
            actions = probabilities.sample()

        squashed = T.tanh(actions)
        action = squashed * T.as_tensor(self.max_action).to(self.device)

        # Change of variables for the tanh squash: subtract the log-Jacobian
        # from the Gaussian log-density instead of replacing it. The term uses
        # the unscaled tanh output so the log argument stays positive for any
        # max_action; the constant log(max_action) offset is omitted.
        log_probs = probabilities.log_prob(actions)
        log_probs = log_probs - T.log(1 - squashed.pow(2) + self.reparam_noise)
        log_probs = log_probs.sum(1, keepdim=True)

        return action, log_probs

    def save(self):
        """Write the state dict to ``chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Restore the state dict from ``chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.load_state_dict(T.load(self.chkpt_file, map_location=self.device))
|