Once again player with rewards, and added clipping to the params
|
@ -8,12 +8,13 @@ class Agent:
|
||||||
|
|
||||||
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
|
def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003,
|
||||||
policy_clip=0.2, batch_size=64, N=2048, n_epochs=10,
|
policy_clip=0.2, batch_size=64, N=2048, n_epochs=10,
|
||||||
gae_lambda=0.95, chkpt_dir='tmp/ppo'):
|
gae_lambda=0.95, entropy_coef=0.001, chkpt_dir='tmp/ppo'):
|
||||||
|
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.policy_clip = policy_clip
|
self.policy_clip = policy_clip
|
||||||
self.n_epochs = n_epochs
|
self.n_epochs = n_epochs
|
||||||
self.gae_lambda = gae_lambda
|
self.gae_lambda = gae_lambda
|
||||||
|
self.entropy_coef = entropy_coef
|
||||||
|
|
||||||
self.actor = ActorNetwork(
|
self.actor = ActorNetwork(
|
||||||
input_dims, n_actions, alpha, chkpt_dir=chkpt_dir)
|
input_dims, n_actions, alpha, chkpt_dir=chkpt_dir)
|
||||||
|
@ -44,6 +45,8 @@ class Agent:
|
||||||
action = T.squeeze(action).item()
|
action = T.squeeze(action).item()
|
||||||
value = T.squeeze(value).item()
|
value = T.squeeze(value).item()
|
||||||
|
|
||||||
|
self.entropy = dist.entropy().mean().item()
|
||||||
|
|
||||||
return action, probs, value
|
return action, probs, value
|
||||||
|
|
||||||
def learn(self):
|
def learn(self):
|
||||||
|
@ -81,8 +84,10 @@ class Agent:
|
||||||
new_probs = dist.log_prob(actions)
|
new_probs = dist.log_prob(actions)
|
||||||
prob_ratio = new_probs.exp() / old_probs.exp()
|
prob_ratio = new_probs.exp() / old_probs.exp()
|
||||||
weighted_probs = advantage[batch] * prob_ratio
|
weighted_probs = advantage[batch] * prob_ratio
|
||||||
|
|
||||||
weighted_clipped_probs = T.clamp(
|
weighted_clipped_probs = T.clamp(
|
||||||
prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
prob_ratio, 1-self.policy_clip, 1+self.policy_clip)*advantage[batch]
|
||||||
|
|
||||||
self.actor_loss = -T.min(weighted_probs,
|
self.actor_loss = -T.min(weighted_probs,
|
||||||
weighted_clipped_probs).mean()
|
weighted_clipped_probs).mean()
|
||||||
|
|
||||||
|
@ -90,10 +95,27 @@ class Agent:
|
||||||
self.critic_loss = (returns - critic_value)**2
|
self.critic_loss = (returns - critic_value)**2
|
||||||
self.critic_loss = self.critic_loss.mean()
|
self.critic_loss = self.critic_loss.mean()
|
||||||
|
|
||||||
self.total_loss = self.actor_loss + 0.5*self.critic_loss
|
self.total_loss = self.actor_loss + 0.5 * \
|
||||||
|
self.critic_loss - self.entropy_coef*self.entropy
|
||||||
|
|
||||||
self.actor.optimizer.zero_grad()
|
self.actor.optimizer.zero_grad()
|
||||||
self.critic.optimizer.zero_grad()
|
self.critic.optimizer.zero_grad()
|
||||||
self.total_loss.backward()
|
self.total_loss.backward()
|
||||||
|
|
||||||
|
# Calculate the gradient norms for both networks
|
||||||
|
actor_grad_norm = T.nn.utils.clip_grad_norm_(
|
||||||
|
self.actor.parameters(), max_norm=1)
|
||||||
|
critic_grad_norm = T.nn.utils.clip_grad_norm_(
|
||||||
|
self.critic.parameters(), max_norm=1)
|
||||||
|
|
||||||
|
T.nn.utils.clip_grad_norm_(
|
||||||
|
self.actor.parameters(), max_norm=1)
|
||||||
|
T.nn.utils.clip_grad_norm_(
|
||||||
|
self.critic.parameters(), max_norm=1)
|
||||||
|
# Log or print the gradient norms
|
||||||
|
print(f"Actor Gradient Norm: {actor_grad_norm}")
|
||||||
|
print(f"Critic Gradient Norm: {critic_grad_norm}")
|
||||||
|
|
||||||
self.actor.optimizer.step()
|
self.actor.optimizer.step()
|
||||||
self.critic.optimizer.step()
|
self.critic.optimizer.step()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
# AI setup
|
|
||||||
N = 20
|
|
||||||
batch_size = 5
|
|
||||||
n_epochs = 4
|
|
||||||
alpha = 0.0003
|
|
||||||
|
|
|
@ -5,40 +5,40 @@
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,500,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,390,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,390,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,391,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,400,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,393,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,392,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
|
||||||
|
|
|
|
@ -1,8 +1,8 @@
|
||||||
monster_data = {
|
monster_data = {
|
||||||
'squid': {'id': 1,
|
'squid': {'id': 1,
|
||||||
'health': .1,
|
'health': 100,
|
||||||
'exp': 1,
|
'exp': 10,
|
||||||
'attack': .5,
|
'attack': 50,
|
||||||
'attack_type': 'slash',
|
'attack_type': 'slash',
|
||||||
'speed': 3,
|
'speed': 3,
|
||||||
'knockback': 20,
|
'knockback': 20,
|
||||||
|
@ -10,19 +10,19 @@ monster_data = {
|
||||||
'notice_radius': 360},
|
'notice_radius': 360},
|
||||||
|
|
||||||
'raccoon': {'id': 2,
|
'raccoon': {'id': 2,
|
||||||
'health': .3,
|
'health': 300,
|
||||||
'exp': 2.5,
|
'exp': 25,
|
||||||
'attack': .8,
|
'attack': 80,
|
||||||
'attack_type': 'claw',
|
'attack_type': 'claw',
|
||||||
'speed': 2,
|
'speed': 2,
|
||||||
'knockback': 20,
|
'knockback': 10,
|
||||||
'attack_radius': 120,
|
'attack_radius': 120,
|
||||||
'notice_radius': 400},
|
'notice_radius': 400},
|
||||||
|
|
||||||
'spirit': {'id': 3,
|
'spirit': {'id': 3,
|
||||||
'health': .1,
|
'health': 80,
|
||||||
'exp': 1.1,
|
'exp': 11,
|
||||||
'attack': .6,
|
'attack': 60,
|
||||||
'attack_type': 'thunder',
|
'attack_type': 'thunder',
|
||||||
'speed': 4,
|
'speed': 4,
|
||||||
'knockback': 20,
|
'knockback': 20,
|
||||||
|
@ -30,9 +30,9 @@ monster_data = {
|
||||||
'notice_radius': 350},
|
'notice_radius': 350},
|
||||||
|
|
||||||
'bamboo': {'id': 4,
|
'bamboo': {'id': 4,
|
||||||
'health': .07,
|
'health': 70,
|
||||||
'exp': 1.2,
|
'exp': 9,
|
||||||
'attack': .2,
|
'attack': 20,
|
||||||
'attack_type': 'leaf_attack',
|
'attack_type': 'leaf_attack',
|
||||||
'speed': 3,
|
'speed': 3,
|
||||||
'knockback': 20,
|
'knockback': 20,
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
tank_stats = {
|
tank_stats = {
|
||||||
'role_id': 1,
|
'role_id': 1,
|
||||||
'health': 1.5,
|
'health': 150,
|
||||||
'energy': .4,
|
'energy': 40,
|
||||||
'attack': .7,
|
'attack': 10,
|
||||||
'magic': .3,
|
'magic': 3,
|
||||||
'speed': 3
|
'speed': 3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ mage_stats = {
|
||||||
'role_id': 2,
|
'role_id': 2,
|
||||||
'health': 70,
|
'health': 70,
|
||||||
'energy': 80,
|
'energy': 80,
|
||||||
'attack': 3,
|
'attack': 7,
|
||||||
'magic': 6,
|
'magic': 6,
|
||||||
'speed': 5
|
'speed': 5
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ warrior_stats = {
|
||||||
'role_id': 3,
|
'role_id': 3,
|
||||||
'health': 100,
|
'health': 100,
|
||||||
'energy': 60,
|
'energy': 60,
|
||||||
'attack': 10,
|
'attack': 12,
|
||||||
'magic': 4,
|
'magic': 4,
|
||||||
'speed': 5
|
'speed': 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,6 +61,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
n_epochs,
|
n_epochs,
|
||||||
gae_lambda,
|
gae_lambda,
|
||||||
chkpt_dir,
|
chkpt_dir,
|
||||||
|
entropy_coef,
|
||||||
no_load=False):
|
no_load=False):
|
||||||
|
|
||||||
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
||||||
|
@ -77,6 +78,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
N=N,
|
N=N,
|
||||||
n_epochs=n_epochs,
|
n_epochs=n_epochs,
|
||||||
gae_lambda=gae_lambda,
|
gae_lambda=gae_lambda,
|
||||||
|
entropy_coef=entropy_coef,
|
||||||
chkpt_dir=chkpt_dir
|
chkpt_dir=chkpt_dir
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
|
@ -150,7 +152,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
def fermi(x, a):
|
def fermi(x, a):
|
||||||
# Used for rescaling features
|
# Used for rescaling features
|
||||||
return 1 / (np.exp(-(x - a)) + 1)
|
return 1 / (np.exp((x - a)) + 1)
|
||||||
|
|
||||||
def maxwell(x, a):
|
def maxwell(x, a):
|
||||||
# Used for rescaling features
|
# Used for rescaling features
|
||||||
|
@ -166,27 +168,32 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
self.action_features = [self._input.action]
|
self.action_features = [self._input.action]
|
||||||
|
|
||||||
self.reward_features = [
|
# self.reward = [
|
||||||
self.stats.exp,
|
# np.log(1 + self.stats.exp),
|
||||||
|
#
|
||||||
|
# fermi(nearest_dist, 50),
|
||||||
|
#
|
||||||
|
# fermi(
|
||||||
|
# nearest_enemy.stats.health,
|
||||||
|
# nearest_enemy.stats.monster_info['health']
|
||||||
|
# ),
|
||||||
|
#
|
||||||
|
# maxwell(
|
||||||
|
# len(self.distance_direction_from_enemy),
|
||||||
|
# self.max_num_enemies
|
||||||
|
# ) - 1,
|
||||||
|
#
|
||||||
|
# - fermi(
|
||||||
|
# self.stats.health,
|
||||||
|
# self.stats.stats['health']
|
||||||
|
# ),
|
||||||
|
# ]
|
||||||
|
|
||||||
fermi(nearest_dist, 10),
|
self.reward = self.stats.exp\
|
||||||
|
+ self.stats.health/self.stats.stats['health'] - 1\
|
||||||
fermi(
|
- nearest_dist/np.sqrt(np.sum(self.map_edge))\
|
||||||
nearest_enemy.stats.health,
|
- nearest_enemy.stats.health/nearest_enemy.stats.monster_info['health']\
|
||||||
nearest_enemy.stats.monster_info['health']
|
- len(self.distance_direction_from_enemy)/self.max_num_enemies
|
||||||
),
|
|
||||||
|
|
||||||
maxwell(
|
|
||||||
len(self.distance_direction_from_enemy),
|
|
||||||
self.max_num_enemies
|
|
||||||
) - 1,
|
|
||||||
|
|
||||||
- fermi(
|
|
||||||
self.stats.health,
|
|
||||||
self.stats.stats['health']
|
|
||||||
),
|
|
||||||
|
|
||||||
]
|
|
||||||
|
|
||||||
self.state_features = [
|
self.state_features = [
|
||||||
self.animation.rect.center[0]/self.map_edge[0],
|
self.animation.rect.center[0]/self.map_edge[0],
|
||||||
|
@ -204,7 +211,7 @@ class Player(pygame.sprite.Sprite):
|
||||||
for distance, direction, enemy in self.distance_direction_from_enemy:
|
for distance, direction, enemy in self.distance_direction_from_enemy:
|
||||||
enemy_states.extend([
|
enemy_states.extend([
|
||||||
|
|
||||||
fermi(distance, 10),
|
distance/np.sqrt(np.sum(self.map_edge)),
|
||||||
|
|
||||||
direction[0],
|
direction[0],
|
||||||
|
|
||||||
|
@ -256,17 +263,12 @@ class Player(pygame.sprite.Sprite):
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
|
||||||
if not self.is_dead():
|
|
||||||
|
|
||||||
self.agent_update()
|
self.agent_update()
|
||||||
|
|
||||||
# Cooldowns and Regen
|
# Cooldowns and Regen
|
||||||
self.stats.health_recovery()
|
self.stats.health_recovery()
|
||||||
self.stats.energy_recovery()
|
self.stats.energy_recovery()
|
||||||
|
|
||||||
else:
|
|
||||||
self.stats.exp = max(-1, self.stats.exp - .1)
|
|
||||||
|
|
||||||
# Refresh player based on input and animate
|
# Refresh player based on input and animate
|
||||||
self.get_status()
|
self.get_status()
|
||||||
self.animation.animate(
|
self.animation.animate(
|
||||||
|
|
Before Width: | Height: | Size: 51 KiB After Width: | Height: | Size: 13 KiB |
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 13 KiB |
BIN
figures/older_figures/actor_loss.png
Normal file
After Width: | Height: | Size: 24 KiB |
BIN
figures/older_figures/critic_loss.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
figures/older_figures/score.png
Normal file
After Width: | Height: | Size: 55 KiB |
BIN
figures/older_figures/total_loss.png
Normal file
After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 20 KiB |
48
level.py
|
@ -91,16 +91,16 @@ class Level:
|
||||||
elif col == '700' and self.n_players > 1:
|
elif col == '700' and self.n_players > 1:
|
||||||
print(f"Prison set at:{(x, y)}")
|
print(f"Prison set at:{(x, y)}")
|
||||||
# Generate grass
|
# Generate grass
|
||||||
if style == 'grass':
|
# if style == 'grass':
|
||||||
random_grass_image = choice(self.graphics['grass'])
|
# random_grass_image = choice(self.graphics['grass'])
|
||||||
|
#
|
||||||
Terrain((x, y), [
|
# Terrain((x, y), [
|
||||||
self.visible_sprites,
|
# self.visible_sprites,
|
||||||
self.obstacle_sprites,
|
# self.obstacle_sprites,
|
||||||
self.attackable_sprites
|
# self.attackable_sprites
|
||||||
],
|
# ],
|
||||||
'grass',
|
# 'grass',
|
||||||
random_grass_image)
|
# random_grass_image)
|
||||||
|
|
||||||
# Generate objects like trees and statues
|
# Generate objects like trees and statues
|
||||||
# if style == 'objects':
|
# if style == 'objects':
|
||||||
|
@ -171,18 +171,18 @@ class Level:
|
||||||
if int(col) != -1:
|
if int(col) != -1:
|
||||||
x = col_index * TILESIZE
|
x = col_index * TILESIZE
|
||||||
y = row_index * TILESIZE
|
y = row_index * TILESIZE
|
||||||
# Regenerate grass
|
# # Regenerate grass
|
||||||
if style == 'grass':
|
# if style == 'grass':
|
||||||
random_grass_image = choice(
|
# random_grass_image = choice(
|
||||||
self.graphics['grass'])
|
# self.graphics['grass'])
|
||||||
|
#
|
||||||
Terrain((x, y), [
|
# Terrain((x, y), [
|
||||||
self.visible_sprites,
|
# self.visible_sprites,
|
||||||
self.obstacle_sprites,
|
# self.obstacle_sprites,
|
||||||
self.attackable_sprites
|
# self.attackable_sprites
|
||||||
],
|
# ],
|
||||||
'grass',
|
# 'grass',
|
||||||
random_grass_image)
|
# random_grass_image)
|
||||||
|
|
||||||
if style == 'entities':
|
if style == 'entities':
|
||||||
|
|
||||||
|
@ -309,6 +309,10 @@ class Level:
|
||||||
debug('PAUSED')
|
debug('PAUSED')
|
||||||
|
|
||||||
for player in self.player_sprites:
|
for player in self.player_sprites:
|
||||||
|
if player.is_dead():
|
||||||
|
print('Player dead')
|
||||||
|
player.stats.exp = -10
|
||||||
|
player.update()
|
||||||
self.dead_players[player.player_id] = player.is_dead()
|
self.dead_players[player.player_id] = player.is_dead()
|
||||||
|
|
||||||
self.done = True if (self.dead_players.all() == 1
|
self.done = True if (self.dead_players.all() == 1
|
||||||
|
|
13
pneuma.py
|
@ -73,6 +73,11 @@ if __name__ == "__main__":
|
||||||
default=0.99,
|
default=0.99,
|
||||||
help="The gamma parameter for PPO")
|
help="The gamma parameter for PPO")
|
||||||
|
|
||||||
|
parser.add_argument('--entropy',
|
||||||
|
type=float,
|
||||||
|
default=0.001,
|
||||||
|
help="The entropy coefficient")
|
||||||
|
|
||||||
parser.add_argument('--alpha',
|
parser.add_argument('--alpha',
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0003,
|
default=0.0003,
|
||||||
|
@ -119,6 +124,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Setup AI stuff
|
# Setup AI stuff
|
||||||
score_history = np.zeros(shape=(n_players, n_episodes))
|
score_history = np.zeros(shape=(n_players, n_episodes))
|
||||||
|
|
||||||
best_score = np.zeros(n_players)
|
best_score = np.zeros(n_players)
|
||||||
|
|
||||||
actor_loss = np.zeros(shape=(n_players,
|
actor_loss = np.zeros(shape=(n_players,
|
||||||
|
@ -142,6 +148,7 @@ if __name__ == "__main__":
|
||||||
N=args.horizon,
|
N=args.horizon,
|
||||||
n_epochs=args.n_epochs,
|
n_epochs=args.n_epochs,
|
||||||
gae_lambda=args.gae_lambda,
|
gae_lambda=args.gae_lambda,
|
||||||
|
entropy_coef=args.entropy,
|
||||||
chkpt_dir=chkpt_path,
|
chkpt_dir=chkpt_path,
|
||||||
no_load=args.no_load
|
no_load=args.no_load
|
||||||
)
|
)
|
||||||
|
@ -189,7 +196,7 @@ if __name__ == "__main__":
|
||||||
# Gather information about the episode
|
# Gather information about the episode
|
||||||
for player in game.level.player_sprites:
|
for player in game.level.player_sprites:
|
||||||
|
|
||||||
score = np.mean(player.reward_features)
|
score = player.reward
|
||||||
|
|
||||||
# Update score
|
# Update score
|
||||||
score_history[player.player_id][episode] = np.mean(score)
|
score_history[player.player_id][episode] = np.mean(score)
|
||||||
|
@ -223,10 +230,10 @@ if __name__ == "__main__":
|
||||||
print(f"Models saved to {chkpt_path}")
|
print(f"Models saved to {chkpt_path}")
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.title("Player Performance")
|
plt.title("Agent Rewards")
|
||||||
plt.xlabel("Episode")
|
plt.xlabel("Episode")
|
||||||
plt.ylabel("Score")
|
plt.ylabel("Score")
|
||||||
plt.legend([f"Player {num}" for num in range(n_players)])
|
plt.legend([f"Agent {num}" for num in range(n_players)])
|
||||||
for player_score in score_history:
|
for player_score in score_history:
|
||||||
plt.plot(player_score)
|
plt.plot(player_score)
|
||||||
plt.savefig(f"{figure_folder}/score.png")
|
plt.savefig(f"{figure_folder}/score.png")
|
||||||
|
|