from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputPolicy


class CustomACPolicy(nn.Module):
    """Hold two independent feature heads: one for the policy, one for the value.

    Each head is a single ``nn.Linear`` layer followed by ``nn.Tanh``. The
    depth and the activation are fixed here; only the input width and the
    output width of each head are configurable.

    Args:
        feature_dim: Size of the input feature vector fed to both heads.
        last_layer_dim_pi: Output width of the policy head.
        last_layer_dim_vf: Output width of the value head.

    Attributes set by the constructor:
        latent_dim_pi: Copy of ``last_layer_dim_pi``.
        latent_dim_vf: Copy of ``last_layer_dim_vf``.
        policy_net: ``Linear(feature_dim, last_layer_dim_pi)`` then ``Tanh``.
        value_net: ``Linear(feature_dim, last_layer_dim_vf)`` then ``Tanh``.
    """

    # NOTE(review): no forward()/forward_actor()/forward_critic() is defined
    # in the part of the class shown here. Presumably this module is meant to
    # be used as a Stable-Baselines3 mlp_extractor, which reads
    # latent_dim_pi / latent_dim_vf -- confirm against the caller.

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ) -> None:
        super().__init__()

        # Output widths of the two heads, exposed as plain attributes.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network. Built before the value network so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.Tanh(),
        )

        # Value network. It has its own Linear layer: no weights are shared
        # with the policy head.
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.Tanh(),
        )