import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from tqdm import tqdm

from transformer import Transformer

# Utils imports
from hparams import batch_size, num_epochs, epoch_test_num
from tokenizer import tokenize, build_vocab, numericalize, stringify

output_dir = 'chkpts'
os.makedirs(output_dir, exist_ok=True)  # Make sure the checkpoint directory exists before saving

# TODO: Better tokenizer
tokens = tokenize("wizard.txt")
word_to_idx = build_vocab(tokens, max_vocab_size=50000)
numericalized_data = numericalize(tokens, word_to_idx)
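
# Optional sanity check (a minimal sketch; it assumes `stringify` accepts any
# sequence of token ids, as in its use at the bottom of this script, and the
# name `idx_to_word_preview` is introduced here only for this check):
# round-trip the first few ids back to words to confirm vocab and data line up.
idx_to_word_preview = {idx: word for word, idx in word_to_idx.items()}
print(stringify(np.array(numericalized_data[:20]), idx_to_word_preview))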

src = torch.tensor(numericalized_data).unsqueeze(0)
trg = torch.tensor(numericalized_data).unsqueeze(0)  # src and trg are the same sequence here

print(f"Total tokens: {trg.shape[1]}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

src_pad_idx = word_to_idx["<PAD>"]
trg_pad_idx = word_to_idx["<PAD>"]
src_vocab_size = len(word_to_idx)
trg_vocab_size = len(word_to_idx)

max_idx = max(numericalized_data)
print(f"Max index in numericalized_data: {max_idx}")
print(f"Vocabulary size: {src_vocab_size}")

if "<PAD>" not in word_to_idx:
    print("Warning: <PAD> token is missing in the vocabulary!")
else:
    print(f"<PAD> index: {word_to_idx['<PAD>']}")

# Arguments used to construct the model, stored in checkpoints so it can be rebuilt later
model_args = dict(
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    dropout=0.1,
)

model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device, dropout=0.1).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)  # Ignore padding tokens in the loss calculation
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.9))

total_loss = []  # Per-step training losses, plotted at the end
best_val_loss = 1e5  # Best training loss seen so far (there is no separate validation set)

for epoch in tqdm(range(num_epochs)):
    model.train()
    # NOTE: batch_size is used here as the chunk length along the (single) sequence
    # dimension, so each "batch" is one contiguous slice of the corpus.
    for i in range(0, len(numericalized_data) - 1, batch_size):
        src_batch = src[:, i:i+batch_size].to(device)
        trg_batch = trg[:, i:i+batch_size].to(device)
        if trg_batch.shape[1] < 2:  # Skip a trailing chunk too short to shift
            continue

        # Ensure the target input is one step ahead of the target output
        trg_input = trg_batch[:, :-1]  # Input to the decoder
        trg_output = trg_batch[:, 1:].contiguous().view(-1)  # Expected output
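        # For example, a chunk [w0, w1, w2, w3] becomes:
        #   decoder input:    [w0, w1, w2]
        #   expected output:  [w1, w2, w3]  (flattened for the loss)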

        # Forward pass
        optimizer.zero_grad()
        output = model(src_batch, trg_input)

        # Calculate loss and backpropagate
        output = output.view(-1, output.shape[2])  # Flatten to (batch * seq_len, vocab_size) for CrossEntropyLoss
        loss = criterion(output, trg_output)
        loss.backward()

        # Update the model parameters
        optimizer.step()

        total_loss.append(loss.item())

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

    if loss.item() < best_val_loss:
        best_val_loss = loss.item()
        if epoch > 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': model_args,
                'epoch': epoch,
                'best_val_loss': best_val_loss,
                # 'config': config
            }
            print(f"Saving checkpoint with loss {best_val_loss} at {output_dir}")
            torch.save(checkpoint, os.path.join(output_dir, 'trnsfrm.pt'))

plt.plot(total_loss)
plt.savefig('total_loss.png')

# Quick qualitative check: a teacher-forced pass over the first 512 tokens
model.eval()
with torch.no_grad():
    out = model(src[:, :512].to(device), trg[:, :511].to(device))

idx_to_word = {idx: word for word, idx in word_to_idx.items()}
decoded_output = stringify(torch.argmax(out, dim=2)[0].cpu().numpy(), idx_to_word)

print(decoded_output)
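
# Optional: greedy autoregressive sampling. A minimal sketch, not part of the
# original training recipe; it only assumes the forward signature used above,
# model(src, trg) -> (batch, trg_len, vocab_size) logits. The names seed_len,
# context, generated and next_token are introduced here for illustration.
seed_len = 16
context = src[:, :seed_len].to(device)    # Encoder input stays fixed
generated = src[:, :seed_len].to(device)  # Decoder prefix grows one token per step

with torch.no_grad():
    for _ in range(50):
        logits = model(context, generated)                            # (1, cur_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # Greedy pick at the last position
        generated = torch.cat([generated, next_token], dim=1)

print(stringify(generated[0].cpu().numpy(), idx_to_word))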