pneuma-pygame/Godot/stable_baselines3_example.py
2024-05-17 01:16:20 +02:00

225 lines
7.9 KiB
Python

import argparse
import os
import pathlib
from typing import Callable
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
# To download the env source and binary:
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
# 2. chmod +x examples/godot_rl_BallChase/bin/BallChase.x86_64
if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatible")
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--env_path",
default=None,
type=str,
help="The Godot binary to use, do not include for in editor training",
)
parser.add_argument(
"--experiment_dir",
default="logs/sb3",
type=str,
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are "
"getting stored.",
)
parser.add_argument(
"--experiment_name",
default="experiment",
type=str,
help="The name of the experiment, which will be displayed in tensorboard and "
"for checkpoint directory and name (if enabled).",
)
parser.add_argument("--seed", type=int, default=0, help="seed of the experiment")
parser.add_argument(
"--resume_model_path",
default=None,
type=str,
help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
"--save_checkpoints_frequency. Use this to resume training or infer from a saved model.",
)
parser.add_argument(
"--save_model_path",
default=None,
type=str,
help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
"to resume training. Extension will be set to .zip",
)
parser.add_argument(
"--save_checkpoint_frequency",
default=None,
type=int,
help=(
"If set, will save checkpoints every 'frequency' environment steps. "
"Requires a unique --experiment_name or --experiment_dir for each run. "
"Does not need --save_model_path to be set. "
),
)
parser.add_argument(
"--onnx_export_path",
default=None,
type=str,
help="If included, will export onnx file after training to the path specified.",
)
parser.add_argument(
"--timesteps",
default=1_000_000,
type=int,
help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
"it will continue training for this amount of steps from the saved state without counting previously trained "
"steps",
)
parser.add_argument(
"--inference",
default=False,
action="store_true",
help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
"Requires --resume_model_path to be set.",
)
parser.add_argument(
"--linear_lr_schedule",
default=False,
action="store_true",
help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at "
"--timesteps"
"value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used.",
)
parser.add_argument(
"--viz",
action="store_true",
help="If set, the simulation will be displayed in a window during training. Otherwise "
"training will run without rendering the simulation. This setting does not apply to in-editor training.",
default=False,
)
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
parser.add_argument(
"--n_parallel",
default=1,
type=int,
help="How many instances of the environment executable to " "launch - requires --env_path to be set if > 1.",
)
args, extras = parser.parse_known_args()
def handle_onnx_export():
# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
# and extension used for both
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))
def handle_model_save():
if args.save_model_path is not None:
zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
print("Saving model to: " + os.path.abspath(zip_save_path))
model.save(zip_save_path)
def close_env():
try:
print("closing env")
env.close()
except Exception as e:
print("Exception while closing env: ", e)
path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)
# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
raise RuntimeError(
abs_path_checkpoint + " folder already exists. "
"Use a different --experiment_dir, or --experiment_name,"
"or if previous checkpoints are not needed anymore, "
"remove the folder containing the checkpoints. "
)
if args.inference and args.resume_model_path is None:
raise parser.error("Using --inference requires --resume_model_path to be set.")
if args.env_path is None and args.viz:
print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.")
env = StableBaselinesGodotEnv(
env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, speedup=args.speedup
)
env = VecMonitor(env)
# LR schedule code snippet from:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
"""
Linear learning rate schedule.
:param initial_value: Initial learning rate.
:return: schedule that computes
current learning rate depending on remaining progress
"""
def func(progress_remaining: float) -> float:
"""
Progress will decrease from 1 (beginning) to 0.
:param progress_remaining:
:return: current learning rate
"""
return progress_remaining * initial_value
return func
if args.resume_model_path is None:
learning_rate = 0.0003 if not args.linear_lr_schedule else linear_schedule(0.0003)
model: PPO = PPO(
"MultiInputPolicy",
env,
batch_size=128,
ent_coef=0.001,
verbose=2,
n_steps=32,
tensorboard_log=args.experiment_dir,
learning_rate=learning_rate,
)
else:
path_zip = pathlib.Path(args.resume_model_path)
print("Loading model: " + os.path.abspath(path_zip))
model = PPO.load(path_zip, env=env, tensorboard_log=args.experiment_dir)
if args.inference:
obs = env.reset()
for i in range(args.timesteps):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
else:
learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.experiment_name)
if args.save_checkpoint_frequency:
print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
checkpoint_callback = CheckpointCallback(
save_freq=(args.save_checkpoint_frequency // env.num_envs),
save_path=path_checkpoint,
name_prefix=args.experiment_name,
)
learn_arguments["callback"] = checkpoint_callback
try:
model.learn(**learn_arguments)
except KeyboardInterrupt:
print(
"""Training interrupted by user. Will save if --save_model_path was
used and/or export if --onnx_export_path was used."""
)
close_env()
handle_onnx_export()
handle_model_save()