Implemented multi policy testing

This commit is contained in:
Vasilis Valatsos 2024-05-20 10:17:56 +02:00
parent 4788dd7783
commit 5f0d5484bf
14 changed files with 295 additions and 192 deletions

159
Godot/agent.py Normal file

@@ -0,0 +1,159 @@
import args
import os
import pathlib
import torch as T
import torch.nn as nn
from typing import Callable
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
def main(policy_name=None, policy=None, parseargs=None):
    if can_import("ray"):
        print("WARNING: SB3 and ray[rllib] are not compatible.")

    args, extras = parseargs
    # args, extras = args.parse_args()

    def handle_onnx_export():
        '''
        Enforces the onnx and zip extensions when saving models.
        This avoids potential conflicts in case of identical names and extensions.
        '''
        if args.onnx_export_path is not None:
            path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
            print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
            export_ppo_model_as_onnx(model, str(path_onnx))

    def handle_model_save():
        if args.save_model_path is not None:
            zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
            print(f"Saving model to: {os.path.abspath(zip_save_path)}")
            model.save(zip_save_path)

    def close_env():
        try:
            print("Closing env...")
            env.close()
        except Exception as e:
            print(f"Exception while closing env: {e}")

    if policy_name is None:
        path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints")
    else:
        path_checkpoint = os.path.join(args.exper_dir, f"{policy_name}_checkpoints")
    abs_path_checkpoint = os.path.abspath(path_checkpoint)

    if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
        raise RuntimeError(
            f"{abs_path_checkpoint} already exists. "
            "Use a different directory or different name. "
            "If you want to override previous checkpoints you have to delete them manually."
        )

    if args.inference and args.resume_model_path is None:
        raise ValueError(
            "Using --inference requires --resume_model_path to be set."
        )

    if args.env_path is None and args.viz:
        print("Info: using --viz without --env_path set has no effect.")
        print("\nIn editor training will always render.")

    env = StableBaselinesGodotEnv(
        env_path=args.env_path,
        show_window=args.viz,
        seed=args.seed,
        n_parallel=args.n_parallel,
        speedup=args.speedup
    )
    env = VecMonitor(env)

    # LR schedule code snippet from:
    # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        """
        Linear learning rate schedule.

        :param initial_value: Initial learning rate.
        :return: schedule that computes
          current learning rate depending on remaining progress
        """
        def func(progress_remaining: float) -> float:
            """
            Progress will decrease from 1 (beginning) to 0.

            :param progress_remaining:
            :return: current learning rate
            """
            return progress_remaining * initial_value

        return func

    if args.resume_model_path is None:
        if not args.linear_lr_schedule:
            learning_rate = 0.0003
        else:
            learning_rate = linear_schedule(0.0003)

        model: PPO = PPO(
            # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
            "MultiInputPolicy",
            env,
            batch_size=64,
            ent_coef=0.01,
            verbose=2,
            n_steps=256,
            tensorboard_log=args.exper_dir,
            learning_rate=learning_rate,
            policy_kwargs=policy,
        )
    else:
        path_zip = pathlib.Path(args.resume_model_path)
        print(f"Loading model: {os.path.abspath(path_zip)}")
        model: PPO = PPO.load(
            path_zip,
            env=env,
            tensorboard_log=args.exper_dir
        )

    if args.inference:
        obs = env.reset()
        for i in range(args.timesteps):
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
    else:
        learn_arguments = dict(
            total_timesteps=args.timesteps,
            tb_log_name=policy_name
        )
        if args.save_checkpoint_frequency:
            print("Checkpoint saving enabled.")
            print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
            checkpoint_callback = CheckpointCallback(
                save_freq=(args.save_checkpoint_frequency // env.num_envs),
                save_path=path_checkpoint,
                name_prefix=policy_name
            )
            learn_arguments["callback"] = checkpoint_callback

        try:
            model.learn(**learn_arguments)
        except KeyboardInterrupt:
            print(
                """
                Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set.
                """
            )

    close_env()
    handle_onnx_export()
    handle_model_save()
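The args module imported above is not part of this commit, so its exact contents are not shown. The following is only a minimal sketch of the interface agent.py and main.py appear to rely on: an argparse parser whose parse_args() returns the (args, extras) pair unpacked in main(). The option names are the attributes actually read above or passed by run_tests.sh; the defaults are hypothetical.

# Hypothetical sketch of args.py; option names taken from the attributes used in agent.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--env_path", default=None, type=str)
parser.add_argument("--exper_dir", default="logs/sb3", type=str)
parser.add_argument("--exper_name", default="experiment", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--speedup", default=1, type=int)
parser.add_argument("--n_parallel", default=1, type=int)
parser.add_argument("--timesteps", default=1_000_000, type=int)
parser.add_argument("--resume_model_path", default=None, type=str)
parser.add_argument("--save_model_path", default=None, type=str)
parser.add_argument("--onnx_export_path", default=None, type=str)
parser.add_argument("--save_checkpoint_frequency", default=None, type=int)
parser.add_argument("--inference", action="store_true")
parser.add_argument("--viz", action="store_true")
parser.add_argument("--linear_lr_schedule", action="store_true")


def parse_args():
    # parse_known_args() returns (Namespace, list of unknown args),
    # matching the `args, extras = parseargs` unpacking in agent.main().
    return parser.parse_known_args()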

Godot/main.py

@@ -1,164 +1,6 @@
import args
import os
import pathlib
import torch as T
import torch.nn as nn
from typing import Callable
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

if can_import("ray"):
    print("WARNING: SB3 and ray[rllib] are not compatible.")

args, extras = args.parse_args()

def handle_onnx_export():
    '''
    Enforces the onnx and zip extentions when saving models.
    This avoids potential conflicts in case of identical names and extentions
    '''
    if args.onnx_export_path is not None:
        path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
        print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
        export_ppo_model_as_onnx(model, str(path_onnx))

def handle_model_save():
    if args.save_model_path is not None:
        zip_save_path = pathlib.Path(Args.save_model_path).with_suffix(".zip")
        print(f"Saving model to: {os.path.abspath(zip_save_path)}")
        model.save(zip_save_path)

def close_env():
    try:
        print("Closing env...")
        env.close()
    except Exception as e:
        print(f"Exception while closing env: {e}")

path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)

if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
    raise RuntimeError(
        f"{abs_path_checkpoint} already exists."
        "Use a different directory or different name."
        "If you want to override previous checkpoints you have to delete them manually."
    )

if args.inference and args.resume_model_path is None:
    raise parser.error(
        "Using --inference requires --resume_model_path to be set."
    )

if args.env_path is None and args.viz:
    print("Info: using --viz without --env_path set has no effect.")
    print("\nIn editor training will always render.")

env = StableBaselinesGodotEnv(
    env_path=args.env_path,
    show_window=args.viz,
    seed=args.seed,
    n_parallel=args.n_parallel,
    speedup=args.speedup
)
env = VecMonitor(env)

# LR schedule code snippet from:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Linear learning rate schedule.
    :param initial_value: Initial learning rate.
    :return: schedule that computes
      current learning rate depending on remaining progress
    """
    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.
        :param progress_remaining:
        :return: current learning rate
        """
        return progress_remaining * initial_value
    return func

policy_kwargs = dict(
    activation_fn=nn.Tanh,
    net_arch = dict(
        pi=[256],
        vf = [2048, 2048]#, 4096, 4096, 4096, 4096]
    ),
    optimizer_class = T.optim.Adam,
    optimizer_kwargs = dict(
        betas=(0.9, 0.9),
        eps=1e-5
    ),
)

if args.resume_model_path is None:
    if not args.linear_lr_schedule:
        learning_rate = 0.0003
    else:
        linear_schedule(0.0003)
    model: PPO = PPO(
        # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
        "MultiInputPolicy",
        env,
        batch_size=64,
        ent_coef=0.01,
        verbose=2,
        n_steps=256,
        tensorboard_log=args.exper_dir,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        # optimizer_kwargs=optimizer_kwargs,
    )
else:
    path_zip = pathlib.Path(args.resume_model_path)
    print(f"Loading model: {os.path.abspath(pathzip)}")
    model: PPO = PPO.load(
        path_zip,
        env=env,
        tensorboard_log=args.exper_dir
    )

if args.inference:
    obs = env.reset()
    for i in range(args.timesteps):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
else:
    learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.exper_name)
    if args.save_checkpoint_frequency:
        print("Checkpoint saving enabled.")
        print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
        checkpoint_callback = CheckpointCallback(
            save_freq=(args.save_checkpoint_frequency // env.num_envs),
            save_path=path_checkpoint,
            name_prefix=args.exper_name
        )
        learn_arguments["callback"] = checkpoint_callback

    try:
        model.learn(**learn_arguments)
    except KeyboardInterrupt:
        print(
            """
            Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set.
            """
        )

close_env()
handle_onnx_export()
handle_model_save()

from policy import policies
from agent import main
import args

for policy_name, policy in policies.items():
    main(policy_name=policy_name, policy=policy, parseargs=args.parse_args())
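Since agent.main() uses the policy name both as tb_log_name and as the checkpoint prefix, each of the twelve configurations ends up as its own run inside exper_dir. A hypothetical variation of the loop above, not part of this commit, that re-runs only the Tanh variants would simply filter the dict first:

from policy import policies
from agent import main
import args

# Hypothetical: train only the *_tanh configurations.
selected = {name: kwargs for name, kwargs in policies.items() if name.endswith("_tanh")}

for policy_name, policy in selected.items():
    main(policy_name=policy_name, policy=policy, parseargs=args.parse_args())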

Godot/policy.py

@@ -1,38 +1,131 @@
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputPolicy


class CustomACPolicy(nn.Module):
    """
    Custom network for policy and value functions.
    It receives as input the number of layers for each network, the activation function and the optimizer parameters.
    """
    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.Tanh()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.Tanh()
        )

import torch as T
import torch.nn as nn

policy_small=dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    )
)
policy_small_optim=dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)
policy_small_tanh=dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[256],
        vf=[256]
    )
)
policy_small_optim_tanh=dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    ),
    optimizer_class=T.optim.Adam,
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)
policy_mid=dict(
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    )
)
policy_mid_tanh=dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    )
)
policy_mid_optim=dict(
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5
    )
)
policy_mid_optim_tanh=dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5
    )
)
policy_big=dict(
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    )
)
policy_big_tanh=dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    )
)
policy_big_optim=dict(
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)
policy_big_optim_tanh=dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096],
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)

policies={
    "policy_small": policy_small,
    "policy_small_optim": policy_small_optim,
    "policy_small_tanh": policy_small_tanh,
    "policy_small_optim_tanh": policy_small_optim_tanh,
    "policy_mid": policy_mid,
    "policy_mid_optim": policy_mid_optim,
    "policy_mid_tanh": policy_mid_tanh,
    "policy_mid_optim_tanh": policy_mid_optim_tanh,
    "policy_big": policy_big,
    "policy_big_optim": policy_big_optim,
    "policy_big_tanh": policy_big_tanh,
    "policy_big_optim_tanh": policy_big_optim_tanh,
}
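Every entry above is a Stable-Baselines3 policy_kwargs dict: net_arch sets the layer sizes of the separate policy (pi) and value (vf) heads, activation_fn pins the activation to Tanh, and the *_optim variants only tune the optimizer's betas/eps (Adam, which SB3 uses by default when optimizer_class is omitted). As a standalone sanity check, any entry can be passed straight to PPO; the snippet below is only a sketch that substitutes CartPole-v1 for the Godot environment used in agent.py.

# Sketch: plug one of the policy_kwargs dicts into PPO outside the Godot pipeline.
import gymnasium as gym
from stable_baselines3 import PPO

from policy import policies

env = gym.make("CartPole-v1")  # stand-in for StableBaselinesGodotEnv
model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policies["policy_small_tanh"],
    verbose=1,
)
model.learn(total_timesteps=10_000)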

9
Godot/run_tests.sh Executable file

@@ -0,0 +1,9 @@
#!/bin/sh
python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_1" &&
python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_2" &&
python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_3" &&
python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_4"
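Each of the four invocations trains all twelve policy configurations into its own exper_dir, so the directories act as repeat runs of the same sweep; because agent.py sets tensorboard_log to exper_dir, the resulting curves can be compared afterwards by pointing TensorBoard at the logs folder (tensorboard --logdir logs).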