160 lines
5.4 KiB
Python
160 lines
5.4 KiB
Python
|
import args
|
||
|
import os
|
||
|
import pathlib
|
||
|
|
||
|
import torch as T
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from typing import Callable
|
||
|
|
||
|
from stable_baselines3 import PPO
|
||
|
from stable_baselines3.common.callbacks import CheckpointCallback
|
||
|
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
|
||
|
|
||
|
from godot_rl.core.utils import can_import
|
||
|
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
|
||
|
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
|
||
|
|
||
|
def main(policy_name=None, policy=None, parseargs=None):
    """Train, or run inference with, a PPO agent on a Godot RL environment.

    :param policy_name: Name used for the checkpoint directory, the
        TensorBoard run name and the checkpoint file prefix. Falls back to
        ``args.exper_name`` when ``None``.
    :param policy: Passed to ``PPO`` as ``policy_kwargs`` when a new model is
        created. It is not used when resuming from ``--resume_model_path``.
    :param parseargs: ``(args, extras)`` tuple of already-parsed command-line
        arguments. Only ``args`` is used.
    :raises ValueError: if ``parseargs`` is missing, or if ``--inference`` is
        set without ``--resume_model_path``.
    :raises RuntimeError: if checkpointing is enabled and the checkpoint
        directory already exists.
    """
    if parseargs is None:
        raise ValueError(
            "main() requires parseargs=(args, extras) from the argument parser."
        )

    if can_import("ray"):
        print("WARNING: SB3 and ray[rllib] are not compatible.")

    # NOTE: this local name shadows the module-level `import args`.
    args, _extras = parseargs

    # A single name shared by the checkpoint directory, the TensorBoard run
    # and the checkpoint files, so the three artifacts stay consistent.
    run_name = policy_name if policy_name is not None else args.exper_name
    # One definition of "checkpointing enabled", used by both the existence
    # check and the callback setup below.
    checkpointing_enabled = bool(args.save_checkpoint_frequency)

    def handle_onnx_export():
        """Export the trained policy to ONNX if --onnx_export_path was given.

        The ".onnx" suffix is enforced so the export cannot clash with a
        ".zip" model saved under the same base name.
        """
        if args.onnx_export_path is not None:
            path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
            print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
            export_ppo_model_as_onnx(model, str(path_onnx))

    def handle_model_save():
        """Save the SB3 model if --save_model_path was given (".zip" suffix enforced)."""
        if args.save_model_path is not None:
            zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
            print(f"Saving model to: {os.path.abspath(zip_save_path)}")
            model.save(zip_save_path)

    def close_env():
        """Close the environment on a best-effort basis: errors are printed, not raised."""
        try:
            print("Closing env...")
            env.close()
        except Exception as e:
            print(f"Exception while closing env: {e}")

    path_checkpoint = os.path.join(args.exper_dir, f"{run_name}_checkpoints")
    abs_path_checkpoint = os.path.abspath(path_checkpoint)

    if checkpointing_enabled and os.path.isdir(path_checkpoint):
        raise RuntimeError(
            f"{abs_path_checkpoint} already exists. "
            "Use a different directory or different name. "
            "If you want to override previous checkpoints you have to delete them manually."
        )

    if args.inference and args.resume_model_path is None:
        raise ValueError("Using --inference requires --resume_model_path to be set.")

    if args.env_path is None and args.viz:
        print("Info: using --viz without --env_path set has no effect.")
        print("\nIn editor training will always render.")

    # LR schedule code snippet from:
    # https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        """
        Linear learning rate schedule.

        :param initial_value: Initial learning rate.
        :return: schedule that computes
            current learning rate depending on remaining progress
        """

        def func(progress_remaining: float) -> float:
            """
            Progress will decrease from 1 (beginning) to 0.

            :param progress_remaining:
            :return: current learning rate
            """
            return progress_remaining * initial_value

        return func

    env = StableBaselinesGodotEnv(
        env_path=args.env_path,
        show_window=args.viz,
        seed=args.seed,
        n_parallel=args.n_parallel,
        speedup=args.speedup,
    )
    env = VecMonitor(env)

    # The env is always closed, even if model creation, training or inference
    # fails, so launched environment processes are not left running.
    try:
        if args.resume_model_path is None:
            if args.linear_lr_schedule:
                learning_rate = linear_schedule(0.0003)
            else:
                learning_rate = 0.0003

            model = PPO(
                # 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
                "MultiInputPolicy",
                env,
                batch_size=64,
                ent_coef=0.01,
                verbose=2,
                n_steps=256,
                tensorboard_log=args.exper_dir,
                learning_rate=learning_rate,
                policy_kwargs=policy,
            )
        else:
            path_zip = pathlib.Path(args.resume_model_path)
            print(f"Loading model: {os.path.abspath(path_zip)}")
            model = PPO.load(
                path_zip,
                env=env,
                tensorboard_log=args.exper_dir,
            )

        if args.inference:
            obs = env.reset()
            for _ in range(args.timesteps):
                action, _state = model.predict(obs, deterministic=True)
                obs, _reward, _done, _info = env.step(action)
        else:
            learn_arguments = dict(
                total_timesteps=args.timesteps,
                tb_log_name=run_name,
            )
            if checkpointing_enabled:
                print("Checkpoint saving enabled.")
                print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
                # save_freq is counted per env.step() call (one step per
                # parallel env), hence the division. It is clamped to at
                # least 1 so a small frequency cannot yield save_freq == 0.
                checkpoint_callback = CheckpointCallback(
                    save_freq=max(1, args.save_checkpoint_frequency // env.num_envs),
                    save_path=path_checkpoint,
                    name_prefix=run_name,
                )
                learn_arguments["callback"] = checkpoint_callback
            try:
                model.learn(**learn_arguments)
            except KeyboardInterrupt:
                print(
                    "\nTraining interrupted by user. Will save if --save_model_path "
                    "was set and/or export if --onnx_export_path was set.\n"
                )
    finally:
        close_env()

    handle_onnx_export()
    handle_model_save()
|