diff --git a/agents/ppo/agent.py b/agents/ppo/agent.py index f287f78..ebc7bba 100644 --- a/agents/ppo/agent.py +++ b/agents/ppo/agent.py @@ -8,12 +8,13 @@ class Agent: def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, policy_clip=0.2, batch_size=64, N=2048, n_epochs=10, - gae_lambda=0.95, chkpt_dir='tmp/ppo'): + gae_lambda=0.95, entropy_coef=0.001, chkpt_dir='tmp/ppo'): self.gamma = gamma self.policy_clip = policy_clip self.n_epochs = n_epochs self.gae_lambda = gae_lambda + self.entropy_coef = entropy_coef self.actor = ActorNetwork( input_dims, n_actions, alpha, chkpt_dir=chkpt_dir) @@ -44,6 +45,8 @@ class Agent: action = T.squeeze(action).item() value = T.squeeze(value).item() + self.entropy = dist.entropy().mean().item() + return action, probs, value def learn(self): @@ -81,8 +84,10 @@ class Agent: new_probs = dist.log_prob(actions) prob_ratio = new_probs.exp() / old_probs.exp() weighted_probs = advantage[batch] * prob_ratio + weighted_clipped_probs = T.clamp( prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch] + self.actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean() @@ -90,10 +95,27 @@ class Agent: self.critic_loss = (returns - critic_value)**2 self.critic_loss = self.critic_loss.mean() - self.total_loss = self.actor_loss + 0.5*self.critic_loss + self.total_loss = self.actor_loss + 0.5 * \ + self.critic_loss - self.entropy_coef*self.entropy + self.actor.optimizer.zero_grad() self.critic.optimizer.zero_grad() self.total_loss.backward() + + # Calculate the gradient norms for both networks + actor_grad_norm = T.nn.utils.clip_grad_norm_( + self.actor.parameters(), max_norm=1) + critic_grad_norm = T.nn.utils.clip_grad_norm_( + self.critic.parameters(), max_norm=1) + + T.nn.utils.clip_grad_norm_( + self.actor.parameters(), max_norm=1) + T.nn.utils.clip_grad_norm_( + self.critic.parameters(), max_norm=1) + # Log or print the gradient norms + print(f"Actor Gradient Norm: {actor_grad_norm}") + print(f"Critic Gradient Norm: {critic_grad_norm}") + self.actor.optimizer.step() self.critic.optimizer.step() diff --git a/agents/ppo/utils/hyperparams.py b/agents/ppo/utils/hyperparams.py deleted file mode 100644 index aca1cfd..0000000 --- a/agents/ppo/utils/hyperparams.py +++ /dev/null @@ -1,6 +0,0 @@ -# AI setup -N = 20 -batch_size = 5 -n_epochs = 4 -alpha = 0.0003 - diff --git a/agents/saved_models/A0 b/agents/saved_models/A0 deleted file mode 100644 index 86a47df..0000000 Binary files a/agents/saved_models/A0 and /dev/null differ diff --git a/agents/saved_models/A1 b/agents/saved_models/A1 deleted file mode 100644 index eb35b63..0000000 Binary files a/agents/saved_models/A1 and /dev/null differ diff --git a/agents/saved_models/C0 b/agents/saved_models/C0 deleted file mode 100644 index 23a1fbf..0000000 Binary files a/agents/saved_models/C0 and /dev/null differ diff --git a/agents/saved_models/C1 b/agents/saved_models/C1 deleted file mode 100644 index 17bd2a1..0000000 Binary files a/agents/saved_models/C1 and /dev/null differ diff --git a/assets/map/Entities.csv b/assets/map/Entities.csv index 5ac0ad4..54888e1 100644 --- a/assets/map/Entities.csv +++ b/assets/map/Entities.csv @@ -5,40 +5,40 @@ -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,390,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 --1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 diff --git a/configs/game/monster_config.py b/configs/game/monster_config.py index 890097f..cf1c451 100644 --- a/configs/game/monster_config.py +++ b/configs/game/monster_config.py @@ -1,8 +1,8 @@ monster_data = { 'squid': {'id': 1, - 'health': .1, - 'exp': 1, - 'attack': .5, + 'health': 100, + 'exp': 10, + 'attack': 50, 'attack_type': 'slash', 'speed': 3, 'knockback': 20, @@ -10,19 +10,19 @@ monster_data = { 'notice_radius': 360}, 'raccoon': {'id': 2, - 'health': .3, - 'exp': 2.5, - 'attack': .8, + 'health': 300, + 'exp': 25, + 'attack': 80, 'attack_type': 'claw', 'speed': 2, - 'knockback': 20, + 'knockback': 10, 'attack_radius': 120, 'notice_radius': 400}, 'spirit': {'id': 3, - 'health': .1, - 'exp': 1.1, - 'attack': .6, + 'health': 80, + 'exp': 11, + 'attack': 60, 'attack_type': 'thunder', 'speed': 4, 'knockback': 20, @@ -30,9 +30,9 @@ monster_data = { 'notice_radius': 350}, 'bamboo': {'id': 4, - 'health': .07, - 'exp': 1.2, - 'attack': .2, + 'health': 70, + 'exp': 9, + 'attack': 20, 'attack_type': 'leaf_attack', 'speed': 3, 'knockback': 20, diff --git a/configs/game/player_config.py b/configs/game/player_config.py index a683841..a3471a9 100644 --- a/configs/game/player_config.py +++ b/configs/game/player_config.py @@ -1,9 +1,9 @@ tank_stats = { 'role_id': 1, - 'health': 1.5, - 'energy': .4, - 'attack': .7, - 'magic': .3, + 'health': 150, + 'energy': 40, + 'attack': 10, + 'magic': 3, 'speed': 3 } @@ -11,7 +11,7 @@ mage_stats = { 'role_id': 2, 'health': 70, 'energy': 80, - 'attack': 3, + 'attack': 7, 'magic': 6, 'speed': 5 } @@ -20,7 +20,7 @@ warrior_stats = { 'role_id': 3, 'health': 100, 'energy': 60, - 'attack': 10, + 'attack': 12, 'magic': 4, 'speed': 5 } diff --git a/entities/player.py b/entities/player.py index c0b1b2f..ca2641e 100644 --- a/entities/player.py +++ b/entities/player.py @@ -61,6 +61,7 @@ class Player(pygame.sprite.Sprite): n_epochs, gae_lambda, chkpt_dir, + entropy_coef, no_load=False): self.max_num_enemies = len(self.distance_direction_from_enemy) @@ -77,6 +78,7 @@ class Player(pygame.sprite.Sprite): N=N, n_epochs=n_epochs, gae_lambda=gae_lambda, + entropy_coef=entropy_coef, chkpt_dir=chkpt_dir ) print( @@ -150,7 +152,7 @@ class Player(pygame.sprite.Sprite): def fermi(x, a): # Used for rescaling features - return 1 / (np.exp(-(x - a)) + 1) + return 1 / (np.exp((x - a)) + 1) def maxwell(x, a): # Used for rescaling features @@ -166,27 +168,32 @@ class Player(pygame.sprite.Sprite): self.action_features = [self._input.action] - self.reward_features = [ - self.stats.exp, + # self.reward = [ + # np.log(1 + self.stats.exp), + # + # fermi(nearest_dist, 50), + # + # fermi( + # nearest_enemy.stats.health, + # nearest_enemy.stats.monster_info['health'] + # ), + # + # maxwell( + # len(self.distance_direction_from_enemy), + # self.max_num_enemies + # ) - 1, + # + # - fermi( + # self.stats.health, + # self.stats.stats['health'] + # ), + # ] - fermi(nearest_dist, 10), - - fermi( - nearest_enemy.stats.health, - nearest_enemy.stats.monster_info['health'] - ), - - maxwell( - len(self.distance_direction_from_enemy), - self.max_num_enemies - ) - 1, - - - fermi( - self.stats.health, - self.stats.stats['health'] - ), - - ] + self.reward = self.stats.exp\ + + self.stats.health/self.stats.stats['health'] - 1\ + - nearest_dist/np.sqrt(np.sum(self.map_edge))\ + - nearest_enemy.stats.health/nearest_enemy.stats.monster_info['health']\ + - len(self.distance_direction_from_enemy)/self.max_num_enemies self.state_features = [ self.animation.rect.center[0]/self.map_edge[0], @@ -204,7 +211,7 @@ class Player(pygame.sprite.Sprite): for distance, direction, enemy in self.distance_direction_from_enemy: enemy_states.extend([ - fermi(distance, 10), + distance/np.sqrt(np.sum(self.map_edge)), direction[0], @@ -256,16 +263,11 @@ class Player(pygame.sprite.Sprite): def update(self): - if not self.is_dead(): + self.agent_update() - self.agent_update() - - # Cooldowns and Regen - self.stats.health_recovery() - self.stats.energy_recovery() - - else: - self.stats.exp = max(-1, self.stats.exp - .1) + # Cooldowns and Regen + self.stats.health_recovery() + self.stats.energy_recovery() # Refresh player based on input and animate self.get_status() diff --git a/figures/actor_loss.png b/figures/actor_loss.png index 7894186..031ad58 100644 Binary files a/figures/actor_loss.png and b/figures/actor_loss.png differ diff --git a/figures/critic_loss.png b/figures/critic_loss.png index b000a1c..c6cac6b 100644 Binary files a/figures/critic_loss.png and b/figures/critic_loss.png differ diff --git a/figures/older_figures/actor_loss.png b/figures/older_figures/actor_loss.png new file mode 100644 index 0000000..58db36e Binary files /dev/null and b/figures/older_figures/actor_loss.png differ diff --git a/figures/older_figures/critic_loss.png b/figures/older_figures/critic_loss.png new file mode 100644 index 0000000..a364bc1 Binary files /dev/null and b/figures/older_figures/critic_loss.png differ diff --git a/figures/older_figures/score.png b/figures/older_figures/score.png new file mode 100644 index 0000000..7e8f973 Binary files /dev/null and b/figures/older_figures/score.png differ diff --git a/figures/older_figures/total_loss.png b/figures/older_figures/total_loss.png new file mode 100644 index 0000000..2636dcb Binary files /dev/null and b/figures/older_figures/total_loss.png differ diff --git a/figures/score.png b/figures/score.png index a51cae3..b4ab992 100644 Binary files a/figures/score.png and b/figures/score.png differ diff --git a/figures/total_loss.png b/figures/total_loss.png index 4e5482d..41c1ce0 100644 Binary files a/figures/total_loss.png and b/figures/total_loss.png differ diff --git a/level.py b/level.py index e43c4fd..62e4ab1 100644 --- a/level.py +++ b/level.py @@ -91,16 +91,16 @@ class Level: elif col == '700' and self.n_players > 1: print(f"Prison set at:{(x, y)}") # Generate grass - if style == 'grass': - random_grass_image = choice(self.graphics['grass']) - - Terrain((x, y), [ - self.visible_sprites, - self.obstacle_sprites, - self.attackable_sprites - ], - 'grass', - random_grass_image) + # if style == 'grass': + # random_grass_image = choice(self.graphics['grass']) + # + # Terrain((x, y), [ + # self.visible_sprites, + # self.obstacle_sprites, + # self.attackable_sprites + # ], + # 'grass', + # random_grass_image) # Generate objects like trees and statues # if style == 'objects': @@ -171,18 +171,18 @@ class Level: if int(col) != -1: x = col_index * TILESIZE y = row_index * TILESIZE - # Regenerate grass - if style == 'grass': - random_grass_image = choice( - self.graphics['grass']) - - Terrain((x, y), [ - self.visible_sprites, - self.obstacle_sprites, - self.attackable_sprites - ], - 'grass', - random_grass_image) + # # Regenerate grass + # if style == 'grass': + # random_grass_image = choice( + # self.graphics['grass']) + # + # Terrain((x, y), [ + # self.visible_sprites, + # self.obstacle_sprites, + # self.attackable_sprites + # ], + # 'grass', + # random_grass_image) if style == 'entities': @@ -309,7 +309,11 @@ class Level: debug('PAUSED') for player in self.player_sprites: - self.dead_players[player.player_id] = player.is_dead() + if player.is_dead(): + print('Player dead') + player.stats.exp = -10 + player.update() + self.dead_players[player.player_id] = player.is_dead() self.done = True if (self.dead_players.all() == 1 or self.enemy_sprites == []) else False diff --git a/pneuma.py b/pneuma.py index 13c4b9e..18da4a6 100644 --- a/pneuma.py +++ b/pneuma.py @@ -73,6 +73,11 @@ if __name__ == "__main__": default=0.99, help="The gamma parameter for PPO") + parser.add_argument('--entropy', + type=float, + default=0.001, + help="The entropy coefficient") + parser.add_argument('--alpha', type=float, default=0.0003, @@ -119,6 +124,7 @@ if __name__ == "__main__": # Setup AI stuff score_history = np.zeros(shape=(n_players, n_episodes)) + best_score = np.zeros(n_players) actor_loss = np.zeros(shape=(n_players, @@ -142,6 +148,7 @@ if __name__ == "__main__": N=args.horizon, n_epochs=args.n_epochs, gae_lambda=args.gae_lambda, + entropy_coef=args.entropy, chkpt_dir=chkpt_path, no_load=args.no_load ) @@ -189,7 +196,7 @@ if __name__ == "__main__": # Gather information about the episode for player in game.level.player_sprites: - score = np.mean(player.reward_features) + score = player.reward # Update score score_history[player.player_id][episode] = np.mean(score) @@ -223,10 +230,10 @@ if __name__ == "__main__": print(f"Models saved to {chkpt_path}") plt.figure() - plt.title("Player Performance") + plt.title("Agent Rewards") plt.xlabel("Episode") plt.ylabel("Score") - plt.legend([f"Player {num}" for num in range(n_players)]) + plt.legend([f"Agent {num}" for num in range(n_players)]) for player_score in score_history: plt.plot(player_score) plt.savefig(f"{figure_folder}/score.png")