Increased size of NNs

This commit is contained in:
Vasilis Valatsos 2024-02-10 18:26:54 +01:00
parent ff8fa7d9e7
commit e9a182d7df

View file

@ -52,7 +52,7 @@ class PPOMemory:
class ActorNetwork(nn.Module): class ActorNetwork(nn.Module):
def __init__(self, input_dim, output_dim, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp/ppo'): def __init__(self, input_dim, output_dim, alpha, fc1_dims=1024, fc2_dims=1024, chkpt_dir='tmp/ppo'):
super(ActorNetwork, self).__init__() super(ActorNetwork, self).__init__()
self.chkpt_dir = chkpt_dir self.chkpt_dir = chkpt_dir
@ -62,6 +62,8 @@ class ActorNetwork(nn.Module):
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc2_dims, output_dim), nn.Linear(fc2_dims, output_dim),
nn.Softmax(dim=-1) nn.Softmax(dim=-1)
) )
@ -89,7 +91,7 @@ class ActorNetwork(nn.Module):
class CriticNetwork(nn.Module): class CriticNetwork(nn.Module):
def __init__(self, input_dims, alpha, fc1_dims=512, fc2_dims=512, chkpt_dir='tmp/ppo'): def __init__(self, input_dims, alpha, fc1_dims=4096, fc2_dims=4096, chkpt_dir='tmp/ppo'):
super(CriticNetwork, self).__init__() super(CriticNetwork, self).__init__()
self.chkpt_dir = chkpt_dir self.chkpt_dir = chkpt_dir
@ -99,6 +101,16 @@ class CriticNetwork(nn.Module):
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims), nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(), nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc1_dims, fc2_dims),
nn.LeakyReLU(),
nn.Linear(fc2_dims, 1) nn.Linear(fc2_dims, 1)
) )