From 656d42f7a088a45f8540234fcbc2e68d8bdf8f1d Mon Sep 17 00:00:00 2001 From: Vasilis Valatsos Date: Sat, 25 May 2024 14:53:05 +0200 Subject: [PATCH] Improved formatting of main for remote --- agent.py | 6 +++--- main.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/agent.py b/agent.py index a7aa494..6006518 100644 --- a/agent.py +++ b/agent.py @@ -15,7 +15,7 @@ from godot_rl.core.utils import can_import from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv -def main(policy_name=None, policy=None, parseargs=None): +def main(policy_name=None, policy=None, parseargs=None, seed=None, speedup=None): if can_import("ray"): print("WARNING: SB3 and ray[rllib] are not compatible.") @@ -71,9 +71,9 @@ def main(policy_name=None, policy=None, parseargs=None): env = StableBaselinesGodotEnv( env_path=args.env_path, show_window=args.viz, - seed=args.seed, + seed=args.seed if seed is None else seed, n_parallel=args.n_parallel, - speedup=args.speedup + speedup=args.speedup if speedup is None else speedup ) env = VecMonitor(env) diff --git a/main.py b/main.py index 0d30b8b..06d898d 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,6 @@ from policy import policies from agent import main import args -for policy_name, policy in policies.items(): - if policy_name in ["policy_big_tanh", "policy_big_optim_tanh"]: - print(policy_name) - main(policy_name=policy_name, policy=policy, parseargs=args.parse_args()) +for i in range(4): + for policy_name, policy in policies.items(): + main(policy_name=policy_name, policy=policy, parseargs=args.parse_args(), speedup=50)