Once again player with rewards, and added clipping to the params
|
@ -8,12 +8,13 @@ class Agent:
|
|||
|
||||
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
|
||||
policy_clip=0.2, batch_size=64, N=2048, n_epochs=10,
|
||||
gae_lambda=0.95, chkpt_dir='tmp/ppo'):
|
||||
gae_lambda=0.95, entropy_coef=0.001, chkpt_dir='tmp/ppo'):
|
||||
|
||||
self.gamma = gamma
|
||||
self.policy_clip = policy_clip
|
||||
self.n_epochs = n_epochs
|
||||
self.gae_lambda = gae_lambda
|
||||
self.entropy_coef = entropy_coef
|
||||
|
||||
self.actor = ActorNetwork(
|
||||
input_dims, n_actions, alpha, chkpt_dir=chkpt_dir)
|
||||
|
@ -44,6 +45,8 @@ class Agent:
|
|||
action = T.squeeze(action).item()
|
||||
value = T.squeeze(value).item()
|
||||
|
||||
self.entropy = dist.entropy().mean().item()
|
||||
|
||||
return action, probs, value
|
||||
|
||||
def learn(self):
|
||||
|
@ -81,8 +84,10 @@ class Agent:
|
|||
new_probs = dist.log_prob(actions)
|
||||
prob_ratio = new_probs.exp() / old_probs.exp()
|
||||
weighted_probs = advantage[batch] * prob_ratio
|
||||
|
||||
weighted_clipped_probs = T.clamp(
|
||||
prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
||||
|
||||
self.actor_loss = -T.min(weighted_probs,
|
||||
weighted_clipped_probs).mean()
|
||||
|
||||
|
@ -90,10 +95,27 @@ class Agent:
|
|||
self.critic_loss = (returns - critic_value)**2
|
||||
self.critic_loss = self.critic_loss.mean()
|
||||
|
||||
self.total_loss = self.actor_loss + 0.5*self.critic_loss
|
||||
self.total_loss = self.actor_loss + 0.5 * \
|
||||
self.critic_loss - self.entropy_coef*self.entropy
|
||||
|
||||
self.actor.optimizer.zero_grad()
|
||||
self.critic.optimizer.zero_grad()
|
||||
self.total_loss.backward()
|
||||
|
||||
# Calculate the gradient norms for both networks
|
||||
actor_grad_norm = T.nn.utils.clip_grad_norm_(
|
||||
self.actor.parameters(), max_norm=1)
|
||||
critic_grad_norm = T.nn.utils.clip_grad_norm_(
|
||||
self.critic.parameters(), max_norm=1)
|
||||
|
||||
T.nn.utils.clip_grad_norm_(
|
||||
self.actor.parameters(), max_norm=1)
|
||||
T.nn.utils.clip_grad_norm_(
|
||||
self.critic.parameters(), max_norm=1)
|
||||
# Log or print the gradient norms
|
||||
print(f"Actor Gradient Norm: {actor_grad_norm}")
|
||||
print(f"Critic Gradient Norm: {critic_grad_norm}")
|
||||
|
||||
self.actor.optimizer.step()
|
||||
self.critic.optimizer.step()
|
||||
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
# AI setup
|
||||
N = 20
|
||||
batch_size = 5
|
||||
n_epochs = 4
|
||||
alpha = 0.0003
|
||||
|
|
@ -5,40 +5,40 @@
|
|||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,390,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||
|
|
|
|
@ -1,8 +1,8 @@
|
|||
monster_data = {
|
||||
'squid': {'id': 1,
|
||||
'health': .1,
|
||||
'exp': 1,
|
||||
'attack': .5,
|
||||
'health': 100,
|
||||
'exp': 10,
|
||||
'attack': 50,
|
||||
'attack_type': 'slash',
|
||||
'speed': 3,
|
||||
'knockback': 20,
|
||||
|
@ -10,19 +10,19 @@ monster_data = {
|
|||
'notice_radius': 360},
|
||||
|
||||
'raccoon': {'id': 2,
|
||||
'health': .3,
|
||||
'exp': 2.5,
|
||||
'attack': .8,
|
||||
'health': 300,
|
||||
'exp': 25,
|
||||
'attack': 80,
|
||||
'attack_type': 'claw',
|
||||
'speed': 2,
|
||||
'knockback': 20,
|
||||
'knockback': 10,
|
||||
'attack_radius': 120,
|
||||
'notice_radius': 400},
|
||||
|
||||
'spirit': {'id': 3,
|
||||
'health': .1,
|
||||
'exp': 1.1,
|
||||
'attack': .6,
|
||||
'health': 80,
|
||||
'exp': 11,
|
||||
'attack': 60,
|
||||
'attack_type': 'thunder',
|
||||
'speed': 4,
|
||||
'knockback': 20,
|
||||
|
@ -30,9 +30,9 @@ monster_data = {
|
|||
'notice_radius': 350},
|
||||
|
||||
'bamboo': {'id': 4,
|
||||
'health': .07,
|
||||
'exp': 1.2,
|
||||
'attack': .2,
|
||||
'health': 70,
|
||||
'exp': 9,
|
||||
'attack': 20,
|
||||
'attack_type': 'leaf_attack',
|
||||
'speed': 3,
|
||||
'knockback': 20,
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
tank_stats = {
|
||||
'role_id': 1,
|
||||
'health': 1.5,
|
||||
'energy': .4,
|
||||
'attack': .7,
|
||||
'magic': .3,
|
||||
'health': 150,
|
||||
'energy': 40,
|
||||
'attack': 10,
|
||||
'magic': 3,
|
||||
'speed': 3
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ mage_stats = {
|
|||
'role_id': 2,
|
||||
'health': 70,
|
||||
'energy': 80,
|
||||
'attack': 3,
|
||||
'attack': 7,
|
||||
'magic': 6,
|
||||
'speed': 5
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ warrior_stats = {
|
|||
'role_id': 3,
|
||||
'health': 100,
|
||||
'energy': 60,
|
||||
'attack': 10,
|
||||
'attack': 12,
|
||||
'magic': 4,
|
||||
'speed': 5
|
||||
}
|
||||
|
|
|
@ -61,6 +61,7 @@ class Player(pygame.sprite.Sprite):
|
|||
n_epochs,
|
||||
gae_lambda,
|
||||
chkpt_dir,
|
||||
entropy_coef,
|
||||
no_load=False):
|
||||
|
||||
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
||||
|
@ -77,6 +78,7 @@ class Player(pygame.sprite.Sprite):
|
|||
N=N,
|
||||
n_epochs=n_epochs,
|
||||
gae_lambda=gae_lambda,
|
||||
entropy_coef=entropy_coef,
|
||||
chkpt_dir=chkpt_dir
|
||||
)
|
||||
print(
|
||||
|
@ -150,7 +152,7 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
def fermi(x, a):
|
||||
# Used for rescaling features
|
||||
return 1 / (np.exp(-(x - a)) + 1)
|
||||
return 1 / (np.exp((x - a)) + 1)
|
||||
|
||||
def maxwell(x, a):
|
||||
# Used for rescaling features
|
||||
|
@ -166,27 +168,32 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
self.action_features = [self._input.action]
|
||||
|
||||
self.reward_features = [
|
||||
self.stats.exp,
|
||||
# self.reward = [
|
||||
# np.log(1 + self.stats.exp),
|
||||
#
|
||||
# fermi(nearest_dist, 50),
|
||||
#
|
||||
# fermi(
|
||||
# nearest_enemy.stats.health,
|
||||
# nearest_enemy.stats.monster_info['health']
|
||||
# ),
|
||||
#
|
||||
# maxwell(
|
||||
# len(self.distance_direction_from_enemy),
|
||||
# self.max_num_enemies
|
||||
# ) - 1,
|
||||
#
|
||||
# - fermi(
|
||||
# self.stats.health,
|
||||
# self.stats.stats['health']
|
||||
# ),
|
||||
# ]
|
||||
|
||||
fermi(nearest_dist, 10),
|
||||
|
||||
fermi(
|
||||
nearest_enemy.stats.health,
|
||||
nearest_enemy.stats.monster_info['health']
|
||||
),
|
||||
|
||||
maxwell(
|
||||
len(self.distance_direction_from_enemy),
|
||||
self.max_num_enemies
|
||||
) - 1,
|
||||
|
||||
- fermi(
|
||||
self.stats.health,
|
||||
self.stats.stats['health']
|
||||
),
|
||||
|
||||
]
|
||||
self.reward = self.stats.exp\
|
||||
+ self.stats.health/self.stats.stats['health'] - 1\
|
||||
- nearest_dist/np.sqrt(np.sum(self.map_edge))\
|
||||
- nearest_enemy.stats.health/nearest_enemy.stats.monster_info['health']\
|
||||
- len(self.distance_direction_from_enemy)/self.max_num_enemies
|
||||
|
||||
self.state_features = [
|
||||
self.animation.rect.center[0]/self.map_edge[0],
|
||||
|
@ -204,7 +211,7 @@ class Player(pygame.sprite.Sprite):
|
|||
for distance, direction, enemy in self.distance_direction_from_enemy:
|
||||
enemy_states.extend([
|
||||
|
||||
fermi(distance, 10),
|
||||
distance/np.sqrt(np.sum(self.map_edge)),
|
||||
|
||||
direction[0],
|
||||
|
||||
|
@ -256,16 +263,11 @@ class Player(pygame.sprite.Sprite):
|
|||
|
||||
def update(self):
|
||||
|
||||
if not self.is_dead():
|
||||
self.agent_update()
|
||||
|
||||
self.agent_update()
|
||||
|
||||
# Cooldowns and Regen
|
||||
self.stats.health_recovery()
|
||||
self.stats.energy_recovery()
|
||||
|
||||
else:
|
||||
self.stats.exp = max(-1, self.stats.exp - .1)
|
||||
# Cooldowns and Regen
|
||||
self.stats.health_recovery()
|
||||
self.stats.energy_recovery()
|
||||
|
||||
# Refresh player based on input and animate
|
||||
self.get_status()
|
||||
|
|
Before Width: | Height: | Size: 51 KiB After Width: | Height: | Size: 13 KiB |
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 13 KiB |
BIN
figures/older_figures/actor_loss.png
Normal file
After Width: | Height: | Size: 24 KiB |
BIN
figures/older_figures/critic_loss.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
figures/older_figures/score.png
Normal file
After Width: | Height: | Size: 55 KiB |
BIN
figures/older_figures/total_loss.png
Normal file
After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 20 KiB |
50
level.py
|
@ -91,16 +91,16 @@ class Level:
|
|||
elif col == '700' and self.n_players > 1:
|
||||
print(f"Prison set at:{(x, y)}")
|
||||
# Generate grass
|
||||
if style == 'grass':
|
||||
random_grass_image = choice(self.graphics['grass'])
|
||||
|
||||
Terrain((x, y), [
|
||||
self.visible_sprites,
|
||||
self.obstacle_sprites,
|
||||
self.attackable_sprites
|
||||
],
|
||||
'grass',
|
||||
random_grass_image)
|
||||
# if style == 'grass':
|
||||
# random_grass_image = choice(self.graphics['grass'])
|
||||
#
|
||||
# Terrain((x, y), [
|
||||
# self.visible_sprites,
|
||||
# self.obstacle_sprites,
|
||||
# self.attackable_sprites
|
||||
# ],
|
||||
# 'grass',
|
||||
# random_grass_image)
|
||||
|
||||
# Generate objects like trees and statues
|
||||
# if style == 'objects':
|
||||
|
@ -171,18 +171,18 @@ class Level:
|
|||
if int(col) != -1:
|
||||
x = col_index * TILESIZE
|
||||
y = row_index * TILESIZE
|
||||
# Regenerate grass
|
||||
if style == 'grass':
|
||||
random_grass_image = choice(
|
||||
self.graphics['grass'])
|
||||
|
||||
Terrain((x, y), [
|
||||
self.visible_sprites,
|
||||
self.obstacle_sprites,
|
||||
self.attackable_sprites
|
||||
],
|
||||
'grass',
|
||||
random_grass_image)
|
||||
# # Regenerate grass
|
||||
# if style == 'grass':
|
||||
# random_grass_image = choice(
|
||||
# self.graphics['grass'])
|
||||
#
|
||||
# Terrain((x, y), [
|
||||
# self.visible_sprites,
|
||||
# self.obstacle_sprites,
|
||||
# self.attackable_sprites
|
||||
# ],
|
||||
# 'grass',
|
||||
# random_grass_image)
|
||||
|
||||
if style == 'entities':
|
||||
|
||||
|
@ -309,7 +309,11 @@ class Level:
|
|||
debug('PAUSED')
|
||||
|
||||
for player in self.player_sprites:
|
||||
self.dead_players[player.player_id] = player.is_dead()
|
||||
if player.is_dead():
|
||||
print('Player dead')
|
||||
player.stats.exp = -10
|
||||
player.update()
|
||||
self.dead_players[player.player_id] = player.is_dead()
|
||||
|
||||
self.done = True if (self.dead_players.all() == 1
|
||||
or self.enemy_sprites == []) else False
|
||||
|
|
13
pneuma.py
|
@ -73,6 +73,11 @@ if __name__ == "__main__":
|
|||
default=0.99,
|
||||
help="The gamma parameter for PPO")
|
||||
|
||||
parser.add_argument('--entropy',
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="The entropy coefficient")
|
||||
|
||||
parser.add_argument('--alpha',
|
||||
type=float,
|
||||
default=0.0003,
|
||||
|
@ -119,6 +124,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Setup AI stuff
|
||||
score_history = np.zeros(shape=(n_players, n_episodes))
|
||||
|
||||
best_score = np.zeros(n_players)
|
||||
|
||||
actor_loss = np.zeros(shape=(n_players,
|
||||
|
@ -142,6 +148,7 @@ if __name__ == "__main__":
|
|||
N=args.horizon,
|
||||
n_epochs=args.n_epochs,
|
||||
gae_lambda=args.gae_lambda,
|
||||
entropy_coef=args.entropy,
|
||||
chkpt_dir=chkpt_path,
|
||||
no_load=args.no_load
|
||||
)
|
||||
|
@ -189,7 +196,7 @@ if __name__ == "__main__":
|
|||
# Gather information about the episode
|
||||
for player in game.level.player_sprites:
|
||||
|
||||
score = np.mean(player.reward_features)
|
||||
score = player.reward
|
||||
|
||||
# Update score
|
||||
score_history[player.player_id][episode] = np.mean(score)
|
||||
|
@ -223,10 +230,10 @@ if __name__ == "__main__":
|
|||
print(f"Models saved to {chkpt_path}")
|
||||
|
||||
plt.figure()
|
||||
plt.title("Player Performance")
|
||||
plt.title("Agent Rewards")
|
||||
plt.xlabel("Episode")
|
||||
plt.ylabel("Score")
|
||||
plt.legend([f"Player {num}" for num in range(n_players)])
|
||||
plt.legend([f"Agent {num}" for num in range(n_players)])
|
||||
for player_score in score_history:
|
||||
plt.plot(player_score)
|
||||
plt.savefig(f"{figure_folder}/score.png")
|
||||
|
|