pneuma-pygame/Godot/policy.py
2024-05-17 01:16:20 +02:00

38 lines
1 KiB
Python

from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from gymnasium import spaces
import torch as th
from torch import nn
from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputPolicy
class CustomACPolicy(nn.Module):
    """
    Actor-critic latent network with two independent single-layer heads.

    Despite the name, this module does not build a full policy. It maps an
    already-extracted feature tensor to two separate latent tensors, one for
    the policy (actor) and one for the value function (critic). Each head is a
    single ``nn.Linear`` layer followed by ``nn.Tanh``, so every latent value
    is squashed into [-1, 1].

    :param feature_dim: size of the last dimension of the input features
        (the ``in_features`` of both linear layers).
    :param last_layer_dim_pi: number of output units of the policy head.
    :param last_layer_dim_vf: number of output units of the value head.
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ) -> None:
        super().__init__()
        # The output sizes are exposed under these names because Stable-Baselines3's
        # ActorCriticPolicy reads them from its mlp_extractor to size the
        # action-distribution and value layers built on top of this network.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf
        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.Tanh(),
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Return ``(latent_pi, latent_vf)``, both computed from the same *features*."""
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the policy-head latent for *features*."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the value-head latent for *features*."""
        return self.value_net(features)