import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from transformer import Transformer

# Utils import
from hparams import batch_size, num_epochs, epoch_test_num
from tokenizer import tokenize, build_vocab, numericalize, stringify

output_dir = 'chkpts'
os.makedirs(output_dir, exist_ok=True)  # torch.save does not create the directory for us

# TODO: Better tokenizer
tokens = tokenize("wizard.txt")
word_to_idx = build_vocab(tokens, max_vocab_size=50000)
numericalized_data = numericalize(tokens, word_to_idx)

# The whole corpus is kept as a single long sequence of token ids (batch dim = 1)
src = torch.tensor(numericalized_data).unsqueeze(0)
trg = torch.tensor(numericalized_data).unsqueeze(0)
print(trg.shape[1])  # total number of tokens in the corpus

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: assumes the tokenizer's padding token is literally "<pad>"
src_pad_idx = word_to_idx["<pad>"]
trg_pad_idx = word_to_idx["<pad>"]
src_vocab_size = len(word_to_idx)
trg_vocab_size = len(word_to_idx)

max_idx = max(numericalized_data)
print(f"Max index in numericalized_data: {max_idx}")
print(f"Vocabulary size: {src_vocab_size}")

if "<pad>" not in word_to_idx:
    print("Warning: <pad> token is missing in the vocabulary!")
else:
    print(f"<pad> index: {word_to_idx['<pad>']}")

model_args = dict()
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx,
                    device=device, dropout=0.1).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)  # Ignore padding tokens in the loss calculation
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.9))

total_loss = []
best_val_loss = 1e5

for epoch in tqdm(range(num_epochs)):
    model.train()
    # Each "batch" is a contiguous chunk of batch_size tokens from the single long sequence
    for i in range(0, len(numericalized_data) - 1, batch_size):
        src_batch = src[:, i:i + batch_size].to(device)
        trg_batch = trg[:, i:i + batch_size].to(device)

        # Shift the target so the expected output is one step ahead of the decoder input
        trg_input = trg_batch[:, :-1]                         # Input to the decoder
        trg_output = trg_batch[:, 1:].contiguous().view(-1)   # Expected output

        # Forward pass
        optimizer.zero_grad()
        output = model(src_batch, trg_input)

        # Calculate loss and backpropagate
        output = output.view(-1, output.shape[2])
        loss = criterion(output, trg_output)
        loss.backward()

        # Update the model parameters
        optimizer.step()

        total_loss.append(loss.item())

    # Report the loss of the last chunk of this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

    if loss.item() < best_val_loss:
        best_val_loss = loss.item()
        if epoch > 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': model_args,
                'epoch': epoch,
                'best_val_loss': best_val_loss,
                # 'config': config
            }
            print(f"Saving checkpoint with loss {best_val_loss} to {output_dir}")
            torch.save(checkpoint, os.path.join(output_dir, 'trnsfrm.pt'))

plt.plot(total_loss)
plt.savefig('total_loss.png')

# Quick teacher-forced sanity check: feed the first 512 source tokens and the
# first 511 target tokens, then decode the argmax predictions back to text
model.eval()
with torch.no_grad():
    out = model(src[:, :512].to(device), trg[:, :511].to(device))

idx_to_word = {idx: word for word, idx in word_to_idx.items()}
decoded_output = stringify(torch.argmax(out, dim=2)[0].cpu().numpy(), idx_to_word)
print(decoded_output)
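
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script): greedy autoregressive
# generation, as opposed to the teacher-forced check above. It assumes the
# Transformer's forward signature is model(src, trg) -> (batch, trg_len, vocab),
# exactly as it is called elsewhere in this file, and that src/trg lengths may
# differ. Names below (generate, seed, sample) are illustrative only.
# ---------------------------------------------------------------------------
def generate(model, src_ids, start_ids, max_new_tokens=50):
    """Greedy decoding sketch: repeatedly pick the most likely next token."""
    model.eval()
    generated = start_ids.clone()  # (1, seed_len)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(src_ids, generated)          # (1, cur_len, vocab)
            next_id = logits[:, -1, :].argmax(dim=-1)   # greedy pick for the last position
            generated = torch.cat([generated, next_id.unsqueeze(1)], dim=1)
    return generated

# Example usage (hypothetical): seed the decoder with the first few tokens.
# seed = trg[:, :10].to(device)
# sample = generate(model, src[:, :512].to(device), seed)
# print(stringify(sample[0].cpu().numpy(), idx_to_word))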