import re
from collections import Counter

import torch

import hparams
def tokenize(filename):
    """Read a text file and return its lowercased whitespace-delimited tokens.

    Args:
        filename: Path to a UTF-8 text file.

    Returns:
        List of tokens — maximal runs of non-whitespace (``\\S+``) from the
        lowercased file contents, in order of appearance.
    """
    # Pin the encoding: bare open() uses the platform default (e.g. cp1252
    # on Windows), which corrupts or rejects non-ASCII corpora.
    with open(filename, 'r', encoding='utf-8') as f:
        text = f.read()
    tokens = re.findall(r'\S+', text.lower())
    return tokens
def build_vocab(tokens, max_vocab_size):
    """Build a word -> index mapping from a token list.

    Index 0 is reserved for "<PAD>" and index 1 for "<UNK>"; the
    max_vocab_size most frequent tokens follow (ties broken by first
    occurrence, since sorted() is stable over Counter insertion order).
    The returned dict therefore holds up to max_vocab_size + 2 entries.
    """
    counts = Counter(tokens)
    by_freq = sorted(counts, key=counts.get, reverse=True)
    vocab = ["<PAD>", "<UNK>", *by_freq[:max_vocab_size]]
    return {word: position for position, word in enumerate(vocab)}
def numericalize(tokens, word_to_idx):
    """Map each token to its index in word_to_idx, falling back to "<UNK>".

    Raises KeyError if word_to_idx has no "<UNK>" entry.
    """
    # Hoist the fallback lookup out of the loop — same result, one lookup.
    unk = word_to_idx["<UNK>"]
    return [word_to_idx.get(tok, unk) for tok in tokens]
def stringify(indices, idx_to_word):
    """Render a sequence of indices as a single space-joined string.

    Raises KeyError/IndexError if an index is missing from idx_to_word.
    """
    words = (idx_to_word[i] for i in indices)
    return ' '.join(words)