# lyceum-env/classes/english/new_training.py


import os
import time
import pickle
import numpy as np
from contextlib import nullcontext
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import mlflow
from transformer import Transformer
from config import Config
# Default config values
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
output_dir = 'chkpts'
os.makedirs(output_dir, exist_ok=True)  # make sure the checkpoint directory exists before saving
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
always_save_checkpoint = True
init_from = 'start'
# Data loading
# dataset = 'english'
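# NOTE (sketch, not part of the original script): get_batch() and meta_vocab_size are
# referenced further down but never defined in this file. The helpers below are a minimal
# nanoGPT-style implementation, assuming a data/english/ directory with train.bin / val.bin
# of uint16 token ids and an optional meta.pkl holding the vocab size; the directory layout
# and file names are assumptions, not something this script specifies.
data_dir = os.path.join('data', 'english')

def get_batch(split):
    # Re-create the memmap on every call so file handles are not held open between batches
    data = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    if device == 'cuda':
        # Pin host memory and copy asynchronously to the GPU
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# Read the tokenizer's vocab size from meta.pkl if the dataset provides one
meta_vocab_size = None
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta.get('vocab_size')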
# Hyperparams (These are currently based off of NanoGPT)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Mixed-precision context for forward passes; plain nullcontext on CPU.
# Assumes a bfloat16-capable GPU; use torch.float16 (with a GradScaler) otherwise.
ctx = nullcontext() if device == "cpu" else torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
compile_model = True  # torch.compile requires PyTorch >= 2.0 ('compile_model' avoids shadowing the builtin compile)
## Transformer Architecture
gradient_accumulation_steps = 5*8
batch_size = 12
block_size = 1024
n_head = 6
n_layer = 6
n_emb = 384
dropout = 0.1
bias = False
## Adam optim (Modified from default)
learning_rate = 3e-5
beta1 = beta2 = 0.9
grad_clip = 1.0
# Save params in dict for saving & MLFlow
model_args = dict(
    bias=bias,
    n_emb=n_emb,
    beta1=beta1,
    beta2=beta2,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
    grad_clip=grad_clip,
    batch_size=batch_size,
    block_size=block_size,
    learning_rate=learning_rate,
)
# TODO: MLFlow logging
mlflow_log = True
if mlflow_log:
    mlflow.set_tracking_uri(uri="http://localhost:5000")
    mlflow.set_experiment("Lyceum English Teacher")
    mlflow.log_params(model_args)
# Estimate Loss
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
        # NOTE: MLflow logging of these estimates happens in the training loop,
        # where the current iter_num is available as the metric step.
    model.train()
    return out
###### INIT MODEL ######
print(f'Initialising from {init_from}')
if init_from == 'start':
    # Fresh run: start the iteration counter and best-val tracker from scratch
    iter_num = 0
    best_val_loss = float('inf')
    if meta_vocab_size is None:
        # TODO: Figure out vocab_size
        print("defaulting to the GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
        model_args['vocab_size'] = 50304
    else:
        model_args['vocab_size'] = meta_vocab_size
    config = Config(**model_args)
    model = Transformer(config)
elif init_from == 'chkpt':
    chkpt_path = os.path.join(output_dir, 'chkpt.pt')
    checkpoint = torch.load(chkpt_path, map_location=device)
    chkpt_args = checkpoint['model_args']
    config = Config(**chkpt_args)
    model = Transformer(config)
    state_dict = checkpoint['model']
    # torch.compile wraps the module and prefixes state_dict keys with '_orig_mod.';
    # strip the prefix so the plain Transformer can load the weights.
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
model.to(device)
###### OPTIMIZER ########
optimizer = model.configure_optim(learning_rate, (beta1, beta2), device)
if init_from == 'chkpt':
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint  # free memory once the optimizer state has been restored
# Compile PyTorch model (Requires PyTorch 2.0)
if compile_model:
    print("compiling the model...")
    unoptimized_model = model
    model = torch.compile(model)
# Keep a handle to the uncompiled module so checkpoints are saved without the '_orig_mod.' prefix
raw_model = unoptimized_model if compile_model else model
master_process = True  # single-process training assumed; no DDP setup in this script
X, Y = get_batch('train')
start_time = time.time()
local_iter_num = 0
running_mfu = -1.0
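# NOTE (sketch, not in the original script): the loop below writes a constant learning rate
# into the optimizer every iteration, which suggests a schedule was intended. A nanoGPT-style
# warmup + cosine decay helper could look like the function below; the warmup_iters,
# lr_decay_iters and min_lr defaults are illustrative placeholders, not values from this repo.
# To use it, the assignment in the loop would become: param_group['lr'] = get_lr(iter_num).
import math

def get_lr(it, warmup_iters=2000, lr_decay_iters=600000, min_lr=3e-6):
    # Linear warmup to learning_rate, then cosine decay down to min_lr
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 down to 0
    return min_lr + coeff * (learning_rate - min_lr)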
while True:
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if mlflow_log:
            mlflow.log_metrics({"train_loss": losses['train'].item(), "val_loss": losses['val'].item()}, step=iter_num)
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'config': config,
                    'iter_num': iter_num,
                    'model_args': model_args,
                    'best_val_loss': best_val_loss,
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                print(f"saving chkpt to '{output_dir}/chkpt.pt'")
                torch.save(checkpoint, os.path.join(output_dir, 'chkpt.pt'))
    if iter_num == 0 and eval_only:
        break
    for step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps