diff --git a/Godot/agent.py b/Godot/agent.py new file mode 100644 index 0000000..a7aa494 --- /dev/null +++ b/Godot/agent.py @@ -0,0 +1,159 @@ +import args +import os +import pathlib + +import torch as T +import torch.nn as nn + +from typing import Callable + +from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import CheckpointCallback +from stable_baselines3.common.vec_env.vec_monitor import VecMonitor + +from godot_rl.core.utils import can_import +from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx +from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv + +def main(policy_name=None, policy=None, parseargs=None): + if can_import("ray"): + print("WARNING: SB3 and ray[rllib] are not compatible.") + + args, extras = parseargs + # args, extras = args.parse_args() + + def handle_onnx_export(): + ''' + Enforces the onnx and zip extentions when saving models. + This avoids potential conflicts in case of identical names and extentions + ''' + if args.onnx_export_path is not None: + path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx") + print(f"Exporting onnx to: {os.path.abspath(path_onnx)}") + export_ppo_model_as_onnx(model, str(path_onnx)) + + def handle_model_save(): + if args.save_model_path is not None: + zip_save_path = pathlib.Path(Args.save_model_path).with_suffix(".zip") + print(f"Saving model to: {os.path.abspath(zip_save_path)}") + model.save(zip_save_path) + + def close_env(): + try: + print("Closing env...") + env.close() + except Exception as e: + print(f"Exception while closing env: {e}") + + if policy_name is None: + path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints") + else: + path_checkpoint = os.path.join(args.exper_dir, f"{policy_name}_checkpoints") + + abs_path_checkpoint = os.path.abspath(path_checkpoint) + + if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint): + raise RuntimeError( + f"{abs_path_checkpoint} already exists." + "Use a different directory or different name." + "If you want to override previous checkpoints you have to delete them manually." + ) + + if args.inference and args.resume_model_path is None: + raise parser.error( + "Using --inference requires --resume_model_path to be set." + ) + + if args.env_path is None and args.viz: + print("Info: using --viz without --env_path set has no effect.") + print("\nIn editor training will always render.") + + env = StableBaselinesGodotEnv( + env_path=args.env_path, + show_window=args.viz, + seed=args.seed, + n_parallel=args.n_parallel, + speedup=args.speedup + ) + env = VecMonitor(env) + + # LR schedule code snippet from: + # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule + def linear_schedule(initial_value: float) -> Callable[[float], float]: + """ + Linear learning rate schedule. + + :param initial_value: Initial learning rate. + :return: schedule that computes + current learning rate depending on remaining progress + """ + + def func(progress_remaining: float) -> float: + """ + Progress will decrease from 1 (beginning) to 0. + + :param progress_remaining: + :return: current learning rate + """ + return progress_remaining * initial_value + + return func + + if args.resume_model_path is None: + if not args.linear_lr_schedule: + learning_rate = 0.0003 + else: + linear_schedule(0.0003) + + model: PPO = PPO( + # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy + "MultiInputPolicy", + env, + batch_size=64, + ent_coef=0.01, + verbose=2, + n_steps=256, + tensorboard_log=args.exper_dir, + learning_rate=learning_rate, + policy_kwargs=policy, + ) + else: + path_zip = pathlib.Path(args.resume_model_path) + print(f"Loading model: {os.path.abspath(pathzip)}") + model: PPO = PPO.load( + path_zip, + env=env, + tensorboard_log=args.exper_dir + ) + + if args.inference: + obs = env.reset() + for i in range(args.timesteps): + action, _state = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + else: + learn_arguments = dict( + total_timesteps=args.timesteps, + tb_log_name=policy_name + ) + if args.save_checkpoint_frequency: + print("Checkpoint saving enabled.") + print(f"\nCheckpoints will be saved to {abs_path_checkpoint}") + checkpoint_callback = CheckpointCallback( + save_freq=(args.save_checkpoint_frequency // env.num_envs), + save_path=path_checkpoint, + name_prefix=policy_name + ) + learn_arguments["callback"] = checkpoint_callback + try: + model.learn(**learn_arguments) + except KeyboardInterrupt: + print( + """ + Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set. + """ + ) + + close_env() + handle_onnx_export() + handle_model_save() diff --git a/Godot/logs/sb3/basic_sb3/events.out.tfevents.1716039274.valanixos.5296.0 b/Godot/logs/sb3/basic_sb3/events.out.tfevents.1716039274.valanixos.5296.0 new file mode 100644 index 0000000..36f9f80 Binary files /dev/null and b/Godot/logs/sb3/basic_sb3/events.out.tfevents.1716039274.valanixos.5296.0 differ diff --git a/Godot/logs/sb3/experiment_1/events.out.tfevents.1716072142.valanixos.132214.0 b/Godot/logs/sb3/experiment_1/events.out.tfevents.1716072142.valanixos.132214.0 new file mode 100644 index 0000000..f9ffd02 Binary files /dev/null and b/Godot/logs/sb3/experiment_1/events.out.tfevents.1716072142.valanixos.132214.0 differ diff --git a/Godot/logs/sb3/experiment_13/events.out.tfevents.1715954191.valanixos.15690.0 b/Godot/logs/sb3/experiment_13/events.out.tfevents.1715954191.valanixos.15690.0 deleted file mode 100644 index 8f3fb69..0000000 Binary files a/Godot/logs/sb3/experiment_13/events.out.tfevents.1715954191.valanixos.15690.0 and /dev/null differ diff --git a/Godot/logs/sb3/experiment_14/events.out.tfevents.1715970707.valanixos.41832.0 b/Godot/logs/sb3/experiment_14/events.out.tfevents.1715970707.valanixos.41832.0 deleted file mode 100644 index df1891e..0000000 Binary files a/Godot/logs/sb3/experiment_14/events.out.tfevents.1715970707.valanixos.41832.0 and /dev/null differ diff --git a/Godot/logs/sb3/experiment_15/events.out.tfevents.1715970870.valanixos.43900.0 b/Godot/logs/sb3/experiment_15/events.out.tfevents.1715970870.valanixos.43900.0 deleted file mode 100644 index f3a6e20..0000000 Binary files a/Godot/logs/sb3/experiment_15/events.out.tfevents.1715970870.valanixos.43900.0 and /dev/null differ diff --git a/Godot/logs/sb3/experiment_2/events.out.tfevents.1716141463.valanixos.1153942.0 b/Godot/logs/sb3/experiment_2/events.out.tfevents.1716141463.valanixos.1153942.0 new file mode 100644 index 0000000..339b7cc Binary files /dev/null and b/Godot/logs/sb3/experiment_2/events.out.tfevents.1716141463.valanixos.1153942.0 differ diff --git a/Godot/logs/sb3/experiment_9/events.out.tfevents.1715893254.valanixos.480596.0 b/Godot/logs/sb3/experiment_9/events.out.tfevents.1715893254.valanixos.480596.0 deleted file mode 100644 index 86be64e..0000000 Binary files a/Godot/logs/sb3/experiment_9/events.out.tfevents.1715893254.valanixos.480596.0 and /dev/null differ diff --git a/Godot/logs/sb3/policy_small_1/events.out.tfevents.1716193022.valanixos.22227.0 b/Godot/logs/sb3/policy_small_1/events.out.tfevents.1716193022.valanixos.22227.0 new file mode 100644 index 0000000..e9efcf9 Binary files /dev/null and b/Godot/logs/sb3/policy_small_1/events.out.tfevents.1716193022.valanixos.22227.0 differ diff --git a/Godot/logs/sb3/sb3_bigger_net/events.out.tfevents.1716047196.valanixos.40801.0 b/Godot/logs/sb3/sb3_bigger_net/events.out.tfevents.1716047196.valanixos.40801.0 new file mode 100644 index 0000000..54e8f5f Binary files /dev/null and b/Godot/logs/sb3/sb3_bigger_net/events.out.tfevents.1716047196.valanixos.40801.0 differ diff --git a/Godot/logs/sb3/sb3_bigger_net_tanh/events.out.tfevents.1716062748.valanixos.97391.0 b/Godot/logs/sb3/sb3_bigger_net_tanh/events.out.tfevents.1716062748.valanixos.97391.0 new file mode 100644 index 0000000..ef34e74 Binary files /dev/null and b/Godot/logs/sb3/sb3_bigger_net_tanh/events.out.tfevents.1716062748.valanixos.97391.0 differ diff --git a/Godot/main.py b/Godot/main.py index 9a0abb1..ead1e90 100644 --- a/Godot/main.py +++ b/Godot/main.py @@ -1,164 +1,6 @@ +from policy import policies +from agent import main import args -import os -import pathlib -import torch as T -import torch.nn as nn - -from typing import Callable - -from stable_baselines3 import PPO -from stable_baselines3.common.callbacks import CheckpointCallback -from stable_baselines3.common.vec_env.vec_monitor import VecMonitor - -from godot_rl.core.utils import can_import -from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx -from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv - -if can_import("ray"): - print("WARNING: SB3 and ray[rllib] are not compatible.") - -args, extras = args.parse_args() - -def handle_onnx_export(): - ''' - Enforces the onnx and zip extentions when saving models. - This avoids potential conflicts in case of identical names and extentions - ''' - if args.onnx_export_path is not None: - path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx") - print(f"Exporting onnx to: {os.path.abspath(path_onnx)}") - export_ppo_model_as_onnx(model, str(path_onnx)) - -def handle_model_save(): - if args.save_model_path is not None: - zip_save_path = pathlib.Path(Args.save_model_path).with_suffix(".zip") - print(f"Saving model to: {os.path.abspath(zip_save_path)}") - model.save(zip_save_path) - -def close_env(): - try: - print("Closing env...") - env.close() - except Exception as e: - print(f"Exception while closing env: {e}") - -path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints") -abs_path_checkpoint = os.path.abspath(path_checkpoint) - -if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint): - raise RuntimeError( - f"{abs_path_checkpoint} already exists." - "Use a different directory or different name." - "If you want to override previous checkpoints you have to delete them manually." - ) - -if args.inference and args.resume_model_path is None: - raise parser.error( - "Using --inference requires --resume_model_path to be set." - ) - -if args.env_path is None and args.viz: - print("Info: using --viz without --env_path set has no effect.") - print("\nIn editor training will always render.") - -env = StableBaselinesGodotEnv( - env_path=args.env_path, - show_window=args.viz, - seed=args.seed, - n_parallel=args.n_parallel, - speedup=args.speedup -) -env = VecMonitor(env) - -# LR schedule code snippet from: -# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule -def linear_schedule(initial_value: float) -> Callable[[float], float]: - """ - Linear learning rate schedule. - - :param initial_value: Initial learning rate. - :return: schedule that computes - current learning rate depending on remaining progress - """ - - def func(progress_remaining: float) -> float: - """ - Progress will decrease from 1 (beginning) to 0. - - :param progress_remaining: - :return: current learning rate - """ - return progress_remaining * initial_value - - return func - -policy_kwargs = dict( - activation_fn=nn.Tanh, - net_arch = dict( - pi=[256], - vf = [2048, 2048]#, 4096, 4096, 4096, 4096] - ), - optimizer_class = T.optim.Adam, - optimizer_kwargs = dict( - betas=(0.9, 0.9), - eps=1e-5 - ), -) - -if args.resume_model_path is None: - if not args.linear_lr_schedule: - learning_rate = 0.0003 - else: - linear_schedule(0.0003) - - model: PPO = PPO( - # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy - "MultiInputPolicy", - env, - batch_size=64, - ent_coef=0.01, - verbose=2, - n_steps=256, - tensorboard_log=args.exper_dir, - learning_rate=learning_rate, - policy_kwargs=policy_kwargs, - # optimizer_kwargs=optimizer_kwargs, - ) -else: - path_zip = pathlib.Path(args.resume_model_path) - print(f"Loading model: {os.path.abspath(pathzip)}") - model: PPO = PPO.load( - path_zip, - env=env, - tensorboard_log=args.exper_dir - ) - -if args.inference: - obs = env.reset() - for i in range(args.timesteps): - action, _state = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) -else: - learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.exper_name) - if args.save_checkpoint_frequency: - print("Checkpoint saving enabled.") - print(f"\nCheckpoints will be saved to {abs_path_checkpoint}") - checkpoint_callback = CheckpointCallback( - save_freq=(args.save_checkpoint_frequency // env.num_envs), - save_path=path_checkpoint, - name_prefix=args.exper_name - ) - learn_arguments["callback"] = checkpoint_callback - try: - model.learn(**learn_arguments) - except KeyboardInterrupt: - print( - """ - Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set. - """ - ) - -close_env() -handle_onnx_export() -handle_model_save() +for policy_name, policy in policies.items(): + main(policy_name=policy_name, policy=policy, parseargs=args.parse_args()) diff --git a/Godot/policy.py b/Godot/policy.py index 6e5ceab..de39514 100644 --- a/Godot/policy.py +++ b/Godot/policy.py @@ -1,38 +1,131 @@ -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +import torch as T +import torch.nn as nn -from gymnasium import spaces -import torch as th -from torch import nn +policy_small=dict( + net_arch=dict( + pi=[256], + vf=[256] + ) +) -from stable_baselines3 import PPO -from stable_baselines3.common.policies import MultiInputPolicy +policy_small_optim=dict( + net_arch=dict( + pi=[256], + vf=[256] + ), + optimizer_kwargs=dict( + betas=(0.9, 0.9), + eps=1e-5, + ), +) -class CustomACPolicy(nn.Module): - """ - Custom network for policy and value functions. +policy_small_tanh=dict( + activation_fn=nn.Tanh, + net_arch=dict( + pi=[256], + vf=[256] + ) +) - It receives as input the number of layers for each network, the activation function and the optimizer parameters. - """ +policy_small_optim_tanh=dict( + net_arch=dict( + pi=[256], + vf=[256] + ), + optimizer_class=T.optim.Adam, + optimizer_kwargs=dict( + betas=(0.9, 0.9), + eps=1e-5, + ), +) - def __init__( - self, - feature_dim: int, - last_layer_dim_pi: int = 64, - last_layer_dim_vf: int = 64, - ): - super().__init__() +policy_mid=dict( + net_arch=dict( + pi=[512], + vf=[2048, 2048] + ) +) - self.latent_dim_pi = last_layer_dim_pi - self.latent_dim_vf = last_layer_dim_vf +policy_mid_tanh=dict( + activation_fn=nn.Tanh, + net_arch=dict( + pi=[512], + vf=[2048, 2048] + ) +) - # Policy network - self.policy_net = nn.Sequential( - nn.Linear(feature_dim, last_layer_dim_pi), - nn.Tanh() - ) +policy_mid_optim=dict( + net_arch=dict( + pi=[512], + vf=[2048, 2048] + ), + optimizer_kwargs=dict( + betas=(0.9,0.9), + eps=1e-5 + ) +) - # Value network - self.value_net = nn.Sequential( - nn.Linear(feature_dim, last_layer_dim_vf), - nn.Tanh() - ) +policy_mid_optim_tanh=dict( + activation_fn=nn.Tanh, + net_arch=dict( + pi=[512], + vf=[2048, 2048] + ), + optimizer_kwargs=dict( + betas=(0.9,0.9), + eps=1e-5 + ) +) + +policy_big=dict( + net_arch=dict( + pi=[1024, 1024], + vf=[4096, 4096, 4096, 4096] + ) +) + +policy_big_tanh=dict( + activation_fn=nn.Tanh, + net_arch=dict( + pi=[1024, 1024], + vf=[4096, 4096, 4096, 4096] + ) +) + +policy_big_optim=dict( + net_arch=dict( + pi=[1024, 1024], + vf=[4096, 4096, 4096, 4096] + ), + optimizer_kwargs=dict( + betas=(0.9, 0.9), + eps=1e-5, + ), +) + +policy_big_optim_tanh = dict( + activation_fn=nn.Tanh, + net_arch=dict( + pi=[1024, 1024], + vf=[4096, 4096, 4096, 4096], + ), + optimizer_kwargs=dict( + betas=(0.9, 0.9), + eps=1e-5, + ), +) + +policies={ + "policy_small": policy_small, + "policy_small_optim": policy_small_optim, + "policy_small_tanh": policy_small_tanh, + "policy_small_optim_tanh": policy_small_optim_tanh, + "policy_mid": policy_mid, + "policy_mid_optim": policy_mid_optim, + "policy_mid_tanh": policy_mid_tanh, + "policy_mid_optim_tanh": policy_mid_optim_tanh, + "policy_big": policy_big, + "policy_big_optim": policy_big_optim, + "policy_big_tanh": policy_big_tanh, + "policy_big_optim_tanh": policy_big_optim_tanh, +} diff --git a/Godot/run_tests.sh b/Godot/run_tests.sh new file mode 100755 index 0000000..e6597f2 --- /dev/null +++ b/Godot/run_tests.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_1" && + +python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_2" && + +python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_3" && + +python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_4"