226 lines
7.9 KiB
Python
226 lines
7.9 KiB
Python
|
import argparse
|
||
|
import os
|
||
|
import pathlib
|
||
|
from typing import Callable
|
||
|
|
||
|
from stable_baselines3 import PPO
|
||
|
from stable_baselines3.common.callbacks import CheckpointCallback
|
||
|
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
|
||
|
|
||
|
from godot_rl.core.utils import can_import
|
||
|
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
|
||
|
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
|
||
|
|
||
|
# To download the env source and binary:
|
||
|
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
|
||
|
# 2. chmod +x examples/godot_rl_BallChase/bin/BallChase.x86_64
|
||
|
if can_import("ray"):
|
||
|
print("WARNING, stable baselines and ray[rllib] are not compatible")
|
||
|
|
||
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||
|
parser.add_argument(
|
||
|
"--env_path",
|
||
|
default=None,
|
||
|
type=str,
|
||
|
help="The Godot binary to use, do not include for in editor training",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--experiment_dir",
|
||
|
default="logs/sb3",
|
||
|
type=str,
|
||
|
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are "
|
||
|
"getting stored.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--experiment_name",
|
||
|
default="experiment",
|
||
|
type=str,
|
||
|
help="The name of the experiment, which will be displayed in tensorboard and "
|
||
|
"for checkpoint directory and name (if enabled).",
|
||
|
)
|
||
|
parser.add_argument("--seed", type=int, default=0, help="seed of the experiment")
|
||
|
parser.add_argument(
|
||
|
"--resume_model_path",
|
||
|
default=None,
|
||
|
type=str,
|
||
|
help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
|
||
|
"--save_checkpoints_frequency. Use this to resume training or infer from a saved model.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--save_model_path",
|
||
|
default=None,
|
||
|
type=str,
|
||
|
help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
|
||
|
"to resume training. Extension will be set to .zip",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--save_checkpoint_frequency",
|
||
|
default=None,
|
||
|
type=int,
|
||
|
help=(
|
||
|
"If set, will save checkpoints every 'frequency' environment steps. "
|
||
|
"Requires a unique --experiment_name or --experiment_dir for each run. "
|
||
|
"Does not need --save_model_path to be set. "
|
||
|
),
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--onnx_export_path",
|
||
|
default=None,
|
||
|
type=str,
|
||
|
help="If included, will export onnx file after training to the path specified.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--timesteps",
|
||
|
default=1_000_000,
|
||
|
type=int,
|
||
|
help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
|
||
|
"it will continue training for this amount of steps from the saved state without counting previously trained "
|
||
|
"steps",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--inference",
|
||
|
default=False,
|
||
|
action="store_true",
|
||
|
help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
|
||
|
"Requires --resume_model_path to be set.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--linear_lr_schedule",
|
||
|
default=False,
|
||
|
action="store_true",
|
||
|
help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at "
|
||
|
"--timesteps"
|
||
|
"value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--viz",
|
||
|
action="store_true",
|
||
|
help="If set, the simulation will be displayed in a window during training. Otherwise "
|
||
|
"training will run without rendering the simulation. This setting does not apply to in-editor training.",
|
||
|
default=False,
|
||
|
)
|
||
|
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
|
||
|
parser.add_argument(
|
||
|
"--n_parallel",
|
||
|
default=1,
|
||
|
type=int,
|
||
|
help="How many instances of the environment executable to " "launch - requires --env_path to be set if > 1.",
|
||
|
)
|
||
|
args, extras = parser.parse_known_args()
|
||
|
|
||
|
|
||
|
def handle_onnx_export():
|
||
|
# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
|
||
|
# and extension used for both
|
||
|
if args.onnx_export_path is not None:
|
||
|
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
|
||
|
print("Exporting onnx to: " + os.path.abspath(path_onnx))
|
||
|
export_ppo_model_as_onnx(model, str(path_onnx))
|
||
|
|
||
|
|
||
|
def handle_model_save():
|
||
|
if args.save_model_path is not None:
|
||
|
zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
|
||
|
print("Saving model to: " + os.path.abspath(zip_save_path))
|
||
|
model.save(zip_save_path)
|
||
|
|
||
|
|
||
|
def close_env():
|
||
|
try:
|
||
|
print("closing env")
|
||
|
env.close()
|
||
|
except Exception as e:
|
||
|
print("Exception while closing env: ", e)
|
||
|
|
||
|
|
||
|
path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
|
||
|
abs_path_checkpoint = os.path.abspath(path_checkpoint)
|
||
|
|
||
|
# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled
|
||
|
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
|
||
|
raise RuntimeError(
|
||
|
abs_path_checkpoint + " folder already exists. "
|
||
|
"Use a different --experiment_dir, or --experiment_name,"
|
||
|
"or if previous checkpoints are not needed anymore, "
|
||
|
"remove the folder containing the checkpoints. "
|
||
|
)
|
||
|
|
||
|
if args.inference and args.resume_model_path is None:
|
||
|
raise parser.error("Using --inference requires --resume_model_path to be set.")
|
||
|
|
||
|
if args.env_path is None and args.viz:
|
||
|
print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.")
|
||
|
|
||
|
env = StableBaselinesGodotEnv(
|
||
|
env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, speedup=args.speedup
|
||
|
)
|
||
|
env = VecMonitor(env)
|
||
|
|
||
|
|
||
|
# LR schedule code snippet from:
|
||
|
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
|
||
|
def linear_schedule(initial_value: float) -> Callable[[float], float]:
|
||
|
"""
|
||
|
Linear learning rate schedule.
|
||
|
|
||
|
:param initial_value: Initial learning rate.
|
||
|
:return: schedule that computes
|
||
|
current learning rate depending on remaining progress
|
||
|
"""
|
||
|
|
||
|
def func(progress_remaining: float) -> float:
|
||
|
"""
|
||
|
Progress will decrease from 1 (beginning) to 0.
|
||
|
|
||
|
:param progress_remaining:
|
||
|
:return: current learning rate
|
||
|
"""
|
||
|
return progress_remaining * initial_value
|
||
|
|
||
|
return func
|
||
|
|
||
|
|
||
|
if args.resume_model_path is None:
|
||
|
learning_rate = 0.0003 if not args.linear_lr_schedule else linear_schedule(0.0003)
|
||
|
model: PPO = PPO(
|
||
|
"MultiInputPolicy",
|
||
|
env,
|
||
|
batch_size=128,
|
||
|
ent_coef=0.001,
|
||
|
verbose=2,
|
||
|
n_steps=32,
|
||
|
tensorboard_log=args.experiment_dir,
|
||
|
learning_rate=learning_rate,
|
||
|
)
|
||
|
else:
|
||
|
path_zip = pathlib.Path(args.resume_model_path)
|
||
|
print("Loading model: " + os.path.abspath(path_zip))
|
||
|
model = PPO.load(path_zip, env=env, tensorboard_log=args.experiment_dir)
|
||
|
|
||
|
if args.inference:
|
||
|
obs = env.reset()
|
||
|
for i in range(args.timesteps):
|
||
|
action, _state = model.predict(obs, deterministic=True)
|
||
|
obs, reward, done, info = env.step(action)
|
||
|
else:
|
||
|
learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.experiment_name)
|
||
|
if args.save_checkpoint_frequency:
|
||
|
print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
|
||
|
checkpoint_callback = CheckpointCallback(
|
||
|
save_freq=(args.save_checkpoint_frequency // env.num_envs),
|
||
|
save_path=path_checkpoint,
|
||
|
name_prefix=args.experiment_name,
|
||
|
)
|
||
|
learn_arguments["callback"] = checkpoint_callback
|
||
|
try:
|
||
|
model.learn(**learn_arguments)
|
||
|
except KeyboardInterrupt:
|
||
|
print(
|
||
|
"""Training interrupted by user. Will save if --save_model_path was
|
||
|
used and/or export if --onnx_export_path was used."""
|
||
|
)
|
||
|
|
||
|
close_env()
|
||
|
handle_onnx_export()
|
||
|
handle_model_save()
|