38 lines
1 KiB
Python
38 lines
1 KiB
Python
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
from gymnasium import spaces
|
|
import torch as th
|
|
from torch import nn
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.policies import MultiInputPolicy
|
|
|
|
class CustomACPolicy(nn.Module):
    """
    Custom feature-to-latent network for the policy and value functions.

    Despite the name, this is not a complete policy. It maps an input feature
    tensor to two separate latent tensors: one for the actor (policy) branch
    and one for the critic (value) branch. Each branch is a single ``Linear``
    layer followed by an activation, and the two branches share no weights.

    The attributes and methods (``latent_dim_pi``, ``latent_dim_vf``,
    ``forward``, ``forward_actor``, ``forward_critic``) follow the
    custom-network interface from the Stable-Baselines3 "Custom Policy
    Network" guide. This lets an instance serve as the ``mlp_extractor`` of
    an ``ActorCriticPolicy``.

    :param feature_dim: size of the last dimension of the input features.
    :param last_layer_dim_pi: number of output units of the policy branch.
    :param last_layer_dim_vf: number of output units of the value branch.
    :param activation_fn: activation class applied after each ``Linear``
        layer. It is instantiated once per branch. Defaults to ``nn.Tanh``,
        the activation this network has always used.
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
        activation_fn: Type[nn.Module] = nn.Tanh,
    ):
        super().__init__()

        # Output widths of each branch. Consumers size their heads from these.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            activation_fn(),
        )

        # Value network: separate Linear instance, nothing shared with policy_net.
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            activation_fn(),
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Compute both latent representations.

        :param features: tensor whose last dimension is ``feature_dim``.
        :return: ``(latent_pi, latent_vf)``, whose last dimensions are
            ``latent_dim_pi`` and ``latent_dim_vf`` respectively.
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the policy-branch latent for ``features``."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the value-branch latent for ``features``."""
        return self.value_net(features)
|