pneuma-godot/agent.py
2024-05-22 08:45:43 +02:00

159 lines
5.4 KiB
Python

import args
import os
import pathlib
import torch as T
import torch.nn as nn
from typing import Callable
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
def main(policy_name=None, policy=None, parseargs=None):
if can_import("ray"):
print("WARNING: SB3 and ray[rllib] are not compatible.")
args, extras = parseargs
# args, extras = args.parse_args()
def handle_onnx_export():
'''
Enforces the onnx and zip extentions when saving models.
This avoids potential conflicts in case of identical names and extentions
'''
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print(f"Exporting onnx to: {os.path.abspath(path_onnx)}")
export_ppo_model_as_onnx(model, str(path_onnx))
def handle_model_save():
if args.save_model_path is not None:
zip_save_path = pathlib.Path(Args.save_model_path).with_suffix(".zip")
print(f"Saving model to: {os.path.abspath(zip_save_path)}")
model.save(zip_save_path)
def close_env():
try:
print("Closing env...")
env.close()
except Exception as e:
print(f"Exception while closing env: {e}")
if policy_name is None:
path_checkpoint = os.path.join(args.exper_dir, f"{args.exper_name}_checkpoints")
else:
path_checkpoint = os.path.join(args.exper_dir, f"{policy_name}_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
raise RuntimeError(
f"{abs_path_checkpoint} already exists."
"Use a different directory or different name."
"If you want to override previous checkpoints you have to delete them manually."
)
if args.inference and args.resume_model_path is None:
raise parser.error(
"Using --inference requires --resume_model_path to be set."
)
if args.env_path is None and args.viz:
print("Info: using --viz without --env_path set has no effect.")
print("\nIn editor training will always render.")
env = StableBaselinesGodotEnv(
env_path=args.env_path,
show_window=args.viz,
seed=args.seed,
n_parallel=args.n_parallel,
speedup=args.speedup
)
env = VecMonitor(env)
# LR schedule code snippet from:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
"""
Linear learning rate schedule.
:param initial_value: Initial learning rate.
:return: schedule that computes
current learning rate depending on remaining progress
"""
def func(progress_remaining: float) -> float:
"""
Progress will decrease from 1 (beginning) to 0.
:param progress_remaining:
:return: current learning rate
"""
return progress_remaining * initial_value
return func
if args.resume_model_path is None:
if not args.linear_lr_schedule:
learning_rate = 0.0003
else:
linear_schedule(0.0003)
model: PPO = PPO(
# 'MultiInputPolicy' serves as an alias for MultiInputActorCriticPolicy
"MultiInputPolicy",
env,
batch_size=64,
ent_coef=0.01,
verbose=2,
n_steps=256,
tensorboard_log=args.exper_dir,
learning_rate=learning_rate,
policy_kwargs=policy,
)
else:
path_zip = pathlib.Path(args.resume_model_path)
print(f"Loading model: {os.path.abspath(pathzip)}")
model: PPO = PPO.load(
path_zip,
env=env,
tensorboard_log=args.exper_dir
)
if args.inference:
obs = env.reset()
for i in range(args.timesteps):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
else:
learn_arguments = dict(
total_timesteps=args.timesteps,
tb_log_name=policy_name
)
if args.save_checkpoint_frequency:
print("Checkpoint saving enabled.")
print(f"\nCheckpoints will be saved to {abs_path_checkpoint}")
checkpoint_callback = CheckpointCallback(
save_freq=(args.save_checkpoint_frequency // env.num_envs),
save_path=path_checkpoint,
name_prefix=policy_name
)
learn_arguments["callback"] = checkpoint_callback
try:
model.learn(**learn_arguments)
except KeyboardInterrupt:
print(
"""
Training interrupted by user. Will save if --save_model_path was set and/or export if --onnx_export was set.
"""
)
close_env()
handle_onnx_export()
handle_model_save()