Implemented multi-policy testing
parent 4788dd7783
commit 5f0d5484bf
14 changed files with 295 additions and 192 deletions
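In short: policy.py now defines a dictionary of candidate policy_kwargs presets, agent.py wraps the whole SB3 training run in a reusable main(), and main.py just loops over the presets. A minimal sketch of the flow (mirroring the new Godot/main.py; the module names are the ones added in this commit):

    import args                      # project-local CLI parsing module
    from agent import main           # one full PPO training run per call
    from policy import policies      # dict: preset name -> policy_kwargs

    # Train one model per preset; each gets its own checkpoint dir and TensorBoard run name.
    for policy_name, policy in policies.items():
        main(policy_name=policy_name, policy=policy, parseargs=args.parse_args())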
Godot/agent.py (new file, 159 lines)
@@ -0,0 +1,159 @@
import args
import os
import pathlib

import torch as T
import torch.nn as nn

from typing import Callable

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

def main(policy_name=None, policy=None, parseargs=None):
    if can_import("ray"):
        print("WARNING: SB3 and ray[rllib] are not compatible.")

    args, extras = parseargs
    # args, extras = args.parse_args()

    def handle_onnx_export():
        '''
        Enforces the onnx and zip extensions when saving models.
        This avoids potential conflicts in case of identical names and extensions.
        '''
        if args.onnx_export_path is not None:
            path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
            print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
            export_ppo_model_as_onnx(model, str(path_onnx))

    def handle_model_save():
        if args.save_model_path is not None:
            zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
            print(f"Saving model to: {os.path.abspath(zip_save_path)}")
            model.save(zip_save_path)

    def close_env():
        try:
            print("Closing env...")
            env.close()
        except Exception as e:
            print(f"Exception while closing env: {e}")

    if policy_name is None:
        path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints")
    else:
        path_checkpoint = os.path.join(args.exper_dir, f"{policy_name}_checkpoints")

    abs_path_checkpoint = os.path.abspath(path_checkpoint)

    if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
        raise RuntimeError(
            f"{abs_path_checkpoint} already exists. "
            "Use a different directory or different name. "
            "If you want to override previous checkpoints you have to delete them manually."
        )

    if args.inference and args.resume_model_path is None:
        # parser is not defined in this module, so raise directly instead of parser.error()
        raise ValueError("Using --inference requires --resume_model_path to be set.")

    if args.env_path is None and args.viz:
        print("Info: using --viz without --env_path set has no effect.")
        print("\nIn editor training will always render.")

    env = StableBaselinesGodotEnv(
        env_path=args.env_path,
        show_window=args.viz,
        seed=args.seed,
        n_parallel=args.n_parallel,
        speedup=args.speedup
    )
    env = VecMonitor(env)

    # LR schedule code snippet from:
    # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        """
        Linear learning rate schedule.

        :param initial_value: Initial learning rate.
        :return: schedule that computes
            current learning rate depending on remaining progress
        """

        def func(progress_remaining: float) -> float:
            """
            Progress will decrease from 1 (beginning) to 0.

            :param progress_remaining:
            :return: current learning rate
            """
            return progress_remaining * initial_value

        return func

    if args.resume_model_path is None:
        if not args.linear_lr_schedule:
            learning_rate = 0.0003
        else:
            learning_rate = linear_schedule(0.0003)

        model: PPO = PPO(
            # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
            "MultiInputPolicy",
            env,
            batch_size=64,
            ent_coef=0.01,
            verbose=2,
            n_steps=256,
            tensorboard_log=args.exper_dir,
            learning_rate=learning_rate,
            policy_kwargs=policy,
        )
    else:
        path_zip = pathlib.Path(args.resume_model_path)
        print(f"Loading model: {os.path.abspath(path_zip)}")
        model: PPO = PPO.load(
            path_zip,
            env=env,
            tensorboard_log=args.exper_dir
        )

    if args.inference:
        obs = env.reset()
        for i in range(args.timesteps):
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
    else:
        learn_arguments = dict(
            total_timesteps=args.timesteps,
            tb_log_name=policy_name
        )
        if args.save_checkpoint_frequency:
            print("Checkpoint saving enabled.")
            print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
            checkpoint_callback = CheckpointCallback(
                save_freq=(args.save_checkpoint_frequency // env.num_envs),
                save_path=path_checkpoint,
                name_prefix=policy_name
            )
            learn_arguments["callback"] = checkpoint_callback
        try:
            model.learn(**learn_arguments)
        except KeyboardInterrupt:
            print(
                """
                Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set.
                """
            )

    close_env()
    handle_onnx_export()
    handle_model_save()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Godot/main.py (166 lines changed)
@@ -1,164 +1,6 @@
from policy import policies
from agent import main
import args
import os
import pathlib

import torch as T
import torch.nn as nn

from typing import Callable

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

if can_import("ray"):
    print("WARNING: SB3 and ray[rllib] are not compatible.")

args, extras = args.parse_args()

def handle_onnx_export():
    '''
    Enforces the onnx and zip extensions when saving models.
    This avoids potential conflicts in case of identical names and extensions.
    '''
    if args.onnx_export_path is not None:
        path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
        print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
        export_ppo_model_as_onnx(model, str(path_onnx))

def handle_model_save():
    if args.save_model_path is not None:
        zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
        print(f"Saving model to: {os.path.abspath(zip_save_path)}")
        model.save(zip_save_path)

def close_env():
    try:
        print("Closing env...")
        env.close()
    except Exception as e:
        print(f"Exception while closing env: {e}")

path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)

if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
    raise RuntimeError(
        f"{abs_path_checkpoint} already exists. "
        "Use a different directory or different name. "
        "If you want to override previous checkpoints you have to delete them manually."
    )

if args.inference and args.resume_model_path is None:
    raise ValueError("Using --inference requires --resume_model_path to be set.")

if args.env_path is None and args.viz:
    print("Info: using --viz without --env_path set has no effect.")
    print("\nIn editor training will always render.")

env = StableBaselinesGodotEnv(
    env_path=args.env_path,
    show_window=args.viz,
    seed=args.seed,
    n_parallel=args.n_parallel,
    speedup=args.speedup
)
env = VecMonitor(env)

# LR schedule code snippet from:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Linear learning rate schedule.

    :param initial_value: Initial learning rate.
    :return: schedule that computes
        current learning rate depending on remaining progress
    """

    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.

        :param progress_remaining:
        :return: current learning rate
        """
        return progress_remaining * initial_value

    return func

policy_kwargs = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[256],
        vf=[2048, 2048]  # , 4096, 4096, 4096, 4096]
    ),
    optimizer_class=T.optim.Adam,
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5
    ),
)

if args.resume_model_path is None:
    if not args.linear_lr_schedule:
        learning_rate = 0.0003
    else:
        learning_rate = linear_schedule(0.0003)

    model: PPO = PPO(
        # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
        "MultiInputPolicy",
        env,
        batch_size=64,
        ent_coef=0.01,
        verbose=2,
        n_steps=256,
        tensorboard_log=args.exper_dir,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        # optimizer_kwargs=optimizer_kwargs,
    )
else:
    path_zip = pathlib.Path(args.resume_model_path)
    print(f"Loading model: {os.path.abspath(path_zip)}")
    model: PPO = PPO.load(
        path_zip,
        env=env,
        tensorboard_log=args.exper_dir
    )

if args.inference:
    obs = env.reset()
    for i in range(args.timesteps):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
else:
    learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.exper_name)
    if args.save_checkpoint_frequency:
        print("Checkpoint saving enabled.")
        print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
        checkpoint_callback = CheckpointCallback(
            save_freq=(args.save_checkpoint_frequency // env.num_envs),
            save_path=path_checkpoint,
            name_prefix=args.exper_name
        )
        learn_arguments["callback"] = checkpoint_callback
    try:
        model.learn(**learn_arguments)
    except KeyboardInterrupt:
        print(
            """
            Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set.
            """
        )

close_env()
handle_onnx_export()
handle_model_save()

for policy_name, policy in policies.items():
    main(policy_name=policy_name, policy=policy, parseargs=args.parse_args())

Godot/policy.py (153 lines changed)
@@ -1,38 +1,131 @@
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch as T
import torch.nn as nn

from gymnasium import spaces
import torch as th
from torch import nn
policy_small = dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    )
)

from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputPolicy
policy_small_optim = dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)

class CustomACPolicy(nn.Module):
    """
    Custom network for policy and value functions.
policy_small_tanh = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[256],
        vf=[256]
    )
)

    It receives as input the number of layers for each network, the activation function and the optimizer parameters.
    """
policy_small_optim_tanh = dict(
    net_arch=dict(
        pi=[256],
        vf=[256]
    ),
    optimizer_class=T.optim.Adam,
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()
policy_mid = dict(
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    )
)

        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf
policy_mid_tanh = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    )
)

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.Tanh()
        )
policy_mid_optim = dict(
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5
    )
)

        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.Tanh()
        )
policy_mid_optim_tanh = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[512],
        vf=[2048, 2048]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5
    )
)

policy_big = dict(
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    )
)

policy_big_tanh = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    )
)

policy_big_optim = dict(
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096]
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)

policy_big_optim_tanh = dict(
    activation_fn=nn.Tanh,
    net_arch=dict(
        pi=[1024, 1024],
        vf=[4096, 4096, 4096, 4096],
    ),
    optimizer_kwargs=dict(
        betas=(0.9, 0.9),
        eps=1e-5,
    ),
)

policies = {
    "policy_small": policy_small,
    "policy_small_optim": policy_small_optim,
    "policy_small_tanh": policy_small_tanh,
    "policy_small_optim_tanh": policy_small_optim_tanh,
    "policy_mid": policy_mid,
    "policy_mid_optim": policy_mid_optim,
    "policy_mid_tanh": policy_mid_tanh,
    "policy_mid_optim_tanh": policy_mid_optim_tanh,
    "policy_big": policy_big,
    "policy_big_optim": policy_big_optim,
    "policy_big_tanh": policy_big_tanh,
    "policy_big_optim_tanh": policy_big_optim_tanh,
}
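For reference, each entry in policies is a plain policy_kwargs dict, so a single preset can also be used on its own. A hypothetical standalone snippet (env here stands in for the VecMonitor-wrapped StableBaselinesGodotEnv built in agent.py):

    from stable_baselines3 import PPO
    from policy import policies

    # env: a StableBaselinesGodotEnv wrapped in VecMonitor, as constructed in agent.py
    model = PPO("MultiInputPolicy", env, policy_kwargs=policies["policy_mid_tanh"], verbose=2)
    model.learn(total_timesteps=10_000)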
Godot/run_tests.sh (new executable file, 9 lines)
@@ -0,0 +1,9 @@
#!/bin/sh

python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_1" &&

python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_2" &&

python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_3" &&

python main.py --env_path="/home/valapeos/Projects/thesis/Godot/pneuma.x86_64" --speedup=200 --n_parallel=4 --exper_dir="logs/sb3_full_4"