lyceum-env/classes/english/train_transformer.py

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from transformer import Transformer
# Utils import
from hparams import batch_size, num_epochs, epoch_test_num
from tokenizer import tokenize, build_vocab, numericalize, stringify
output_dir = 'chkpts'
os.makedirs(output_dir, exist_ok=True)  # make sure the checkpoint directory exists before saving
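
# Tokenize the corpus, build a word-level vocabulary, and map the token stream to indices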
# TODO: Better tokenizer
tokens = tokenize("wizard.txt")
word_to_idx = build_vocab(tokens, max_vocab_size=50000)
numericalized_data = numericalize(tokens, word_to_idx)
src = torch.tensor(numericalized_data).unsqueeze(0)
trg = torch.tensor(numericalized_data).unsqueeze(0)
print(f"Sequence length: {trg.shape[1]}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_pad_idx = word_to_idx["<PAD>"]
trg_pad_idx = word_to_idx["<PAD>"]
src_vocab_size = len(word_to_idx)
trg_vocab_size = len(word_to_idx)
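
# Sanity checks: the largest token index must fit inside the vocabulary, and <PAD> must be present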
max_idx = max(numericalized_data)
print(f"Max index in numericalized_data: {max_idx}")
print(f"Vocabulary size: {src_vocab_size}")
if "<PAD>" not in word_to_idx:
print("Warning: <PAD> token is missing in the vocabulary!")
else:
print(f"<PAD> index: {word_to_idx['<PAD>']}")
model_args = dict()
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device, dropout=0.1).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx) # Ignore padding tokens in the loss calculation
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.9))
total_loss = []
best_val_loss = 1e5
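
# Training loop: step through the token stream in chunks of batch_size and train with teacher forcing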
for epoch in tqdm(range(num_epochs)):
    model.train()
    for i in range(0, len(numericalized_data) - 1, batch_size):
        src_batch = src[:, i:i+batch_size].to(device)
        trg_batch = trg[:, i:i+batch_size].to(device)
        # Shift the target: the decoder input drops the last token, the expected output drops the first
        trg_input = trg_batch[:, :-1]  # Input to the decoder
        trg_output = trg_batch[:, 1:].contiguous().view(-1)  # Expected output
        # Forward pass
        optimizer.zero_grad()
        output = model(src_batch, trg_input)
        # Flatten the logits, calculate the loss, and backpropagate
        output = output.view(-1, output.shape[2])
        loss = criterion(output, trg_output)
        loss.backward()
        # Update the model parameters
        optimizer.step()
    total_loss.append(loss.item())
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
    if loss.item() < best_val_loss:
        best_val_loss = loss.item()
        if epoch > 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': model_args,
                'epoch': epoch,
                'best_val_loss': best_val_loss,
                # 'config': config
            }
            print(f"Saving checkpoint with loss {best_val_loss} at {output_dir}")
            torch.save(checkpoint, os.path.join(output_dir, 'trnsfrm.pt'))
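
# Plot the recorded loss curve, then decode the first 512 tokens as a quick qualitative check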
plt.plot(total_loss)
plt.savefig('total_loss.png')

model.eval()
with torch.no_grad():
    out = model(src[:, :512].to(device), trg[:, :511].to(device))
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
decoded_output = stringify(torch.argmax(out, dim=2)[0].cpu().numpy(), idx_to_word)
print(decoded_output)