# melite/gpt.nim — Nimertes GPT training script (66 lines, last modified 2024-10-22)
import std / [ tables, os ]
import random
import arraymancer
import plotly
import progress
import therapist
import ./batcher
import ./hparams
import ./textEncoder
import ./generator
# Seed the global RNG so every run starts from a fresh random state.
randomize()

# Command-line interface (therapist): a save/load directory plus the help flag.
let spec = (
  dir: newStringArg(@["-d", "--dir"], defaultVal = "defaultDir",
                    help = "Directory to save/load from."),
  help: newHelpArg(@["-h", "--help"], help = "Show help message"),
)
spec.parseOrQuit("Nimertes")

# Directory chosen on the command line (falls back to "defaultDir").
let dirName = spec.dir.value

# Progress bar ticking once per training epoch (numEpochs comes from hparams).
var bar = newProgressBar(total = numEpochs)
bar.start()
###### Text encoding
# Size of the character vocabulary — one id per distinct character.
# NOTE(review): stringToInt comes from textEncoder; presumably a char -> id
# table, so its length is the vocabulary size — confirm against textEncoder.
let vocabSize = stringToInt.len
# var encodedText: seq[int] = encodeString(textContent, stringToInt)
# ###### Split corpus into training and validation sets #######
# const perchentageTraining = 80 # how much % of the corpus is given for training.
# let trainingSetEnd:int = (perchentageTraining*encodedText.len/100).int
# let trainingSet: seq[int] = encodedText[0..trainingSetEnd]
# let validationSet: seq[int] = encodedText[trainingSetEnd..textContent.len-1]
###### Define NN
# Autograd context that tracks every float32 tensor operation for backprop.
let ctx = newContext Tensor[float32]

# TODO: make Block type (transformer block) for Nimertes
# type Block:
network NimertesGPT:
  layers:
    tokenEmbedder: Embedding(vocabSize, hiddenSize)
    positionEmbedder: Embedding(blockSize, nEmbed)
    # blockLayer: Block(nEmbed,)
    # NOTE(review): languageModelHead is declared but never used in `forward`;
    # the chain currently ends at outputLayer (nEmbed wide), not vocab logits.
    languageModelHead: Linear(nEmbed, vocabSize)
    hiddenLinear: Linear(hiddenSize, hiddenSize)
    outputLayer: Linear(hiddenSize, nEmbed)
  forward x:
    # Token embedding lookup for the input ids.
    let tokenEmbedding = x.tokenEmbedder()
    # BUG FIX: original read `positionEmbedding = .positionEmbedder()` — the
    # receiver was missing (syntax error). Looked up from the same input here.
    # NOTE(review): a GPT normally feeds *positions* (0 ..< seqLen), not token
    # ids, to the position embedder — confirm the intended input.
    let positionEmbedding = x.positionEmbedder()
    # Combine token + position embeddings, then run the tanh MLP stack.
    # (The original chained the two locals as if they were layers, which does
    # not compile; element-wise sum is the standard GPT combination.)
    # NOTE(review): tokenEmbedder outputs hiddenSize and positionEmbedder
    # outputs nEmbed — the sum only type-checks if hiddenSize == nEmbed;
    # verify the values in hparams.
    (tokenEmbedding + positionEmbedding).tanh.hiddenLinear.tanh.hiddenLinear.tanh.outputLayer
###### Initialize NN
var
  # Instantiate NimertesGPT: allocates the layer weights and registers them
  # with `ctx` so gradients flow through them during training.
  model = ctx.init(NimertesGPT)