Added score_metric and loading options
This commit is contained in:
parent
aaaf7a2829
commit
7eb7228a8c
4 changed files with 13 additions and 34 deletions
7
args.py
7
args.py
|
@ -13,7 +13,7 @@ def parse_args():
|
|||
parser.add_argument('--no_seed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Set to True to run without a seed.")
|
||||
help="Set to run without a seed.")
|
||||
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
|
@ -83,6 +83,11 @@ def parse_args():
|
|||
action="store_true",
|
||||
help="Set flag to disable learning. Useful for viewing trained agents interact in the environment.")
|
||||
|
||||
parser.add_argument('--load',
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run id to load within chkpt_path")
|
||||
|
||||
parser.add_argument('--show_pg',
|
||||
default=False,
|
||||
action="store_true",
|
||||
|
|
|
@ -61,7 +61,7 @@ class Player(pygame.sprite.Sprite):
|
|||
gae_lambda,
|
||||
chkpt_dir,
|
||||
entropy_coef,
|
||||
no_load=False):
|
||||
load=None):
|
||||
|
||||
self.max_num_enemies = len(self.distance_direction_from_enemy)
|
||||
self.get_current_state()
|
||||
|
@ -82,12 +82,12 @@ class Player(pygame.sprite.Sprite):
|
|||
print(
|
||||
f"\nAgent initialized on player {self.player_id} using {self.agent.actor.device}.")
|
||||
|
||||
if not no_load:
|
||||
if load:
|
||||
print("Attempting to load models ...")
|
||||
try:
|
||||
self.agent.load_models(
|
||||
actr_chkpt=f"A{self.player_id}",
|
||||
crtc_chkpt=f"C{self.player_id}"
|
||||
actr_chkpt=f"run{load}/A{self.player_id}",
|
||||
crtc_chkpt=f"run{load}/C{self.player_id}"
|
||||
)
|
||||
print("Models loaded ...\n")
|
||||
|
||||
|
|
4
main.py
4
main.py
|
@ -80,7 +80,7 @@ def main():
|
|||
gae_lambda=parsed_args.gae_lambda,
|
||||
entropy_coef=parsed_args.entropy_coeff,
|
||||
chkpt_dir=chkpt_path,
|
||||
no_load=True
|
||||
load=parsed_args.load
|
||||
)
|
||||
|
||||
# Episodes start
|
||||
|
@ -173,6 +173,8 @@ def main():
|
|||
|
||||
metrics.plot_learning_curve(score_history, parsed_args.n_agents, figure_path)
|
||||
|
||||
metrics.plot_score(score_history, parsed_args.n_agents, figure_path)
|
||||
|
||||
metrics.plot_loss('actor', actor_loss, parsed_args.n_agents, figure_path)
|
||||
|
||||
metrics.plot_loss('critic', critic_loss, parsed_args.n_agents, figure_path)
|
||||
|
|
|
@ -3,34 +3,6 @@ import numpy as np
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def generate(parsed_args):
|
||||
|
||||
# Setup parameter monitoring
|
||||
score_history = np.zeros(
|
||||
shape=(parsed_args.n_agents, parsed_args.n_episodes))
|
||||
|
||||
best_score = np.zeros(parsed_args.n_agents)
|
||||
|
||||
actor_loss = np.zeros(shape=(parsed_args.n_agents,
|
||||
parsed_args.n_episodes))
|
||||
|
||||
critic_loss = np.zeros(shape=(parsed_args.n_agents,
|
||||
parsed_args.n_episodes))
|
||||
|
||||
total_loss = np.zeros(shape=(parsed_args.n_agents,
|
||||
parsed_args.n_episodes))
|
||||
|
||||
entropy = np.zeros(shape=(parsed_args.n_agents,
|
||||
parsed_args.n_episodes))
|
||||
|
||||
advantage = np.zeros(shape=(parsed_args.n_agents,
|
||||
parsed_args.n_episodes))
|
||||
|
||||
return score_history, best_score, actor_loss,
|
||||
critic_loss, total_loss, entropy,
|
||||
advantage
|
||||
|
||||
|
||||
def plot_learning_curve(scores, num_players, figure_path):
|
||||
|
||||
plt.figure()
|
||||
|
|
Loading…
Reference in a new issue