# Viewer metadata: 133 lines, 4 KiB, Nim, executable file
import std / [ tables, os, strformat ]
|
|
import random
|
|
|
|
import arraymancer
|
|
|
|
import ./batcher
|
|
import ./hparams
|
|
import ./generator
|
|
import ./textEncoder
|
|
|
|
|
|
# Seed the global random number generator once at start-up.
randomize()

###### Text encoding
# `stringToInt`, `textContent` and `encodeString` are provided by the imported
# project modules (presumably ./textEncoder — confirm).
let vocabSize: int = stringToInt.len()

# The whole corpus as a sequence of integer token ids.
let encodedText: seq[int] = encodeString(textContent, stringToInt)

###### Split corpus into training and validation sets #######
const perchentageTraining = 80 # how much % of the corpus is given for training.

# Index of the first validation token. Integer division truncates exactly as
# the previous float computation did for these non-negative operands.
let trainingSetEnd: int = (perchentageTraining * encodedText.len) div 100

# Half-open split: the training set stops just before `trainingSetEnd` and the
# validation set starts at it, so no token belongs to both sets.
let trainingSet: seq[int] = encodedText[0 ..< trainingSetEnd]

# Bounded by the encoded sequence itself (not by the raw text length), so the
# slice stays valid even when the encoding is not one id per byte.
let validationSet: seq[int] = encodedText[trainingSetEnd .. ^1]
|
|
|
|
###### Define NN
|
|
# Autograd context shared by the model, the optimizer and text generation.
let ctx = newContext(Tensor[float32])

# Network definition: an embedding, one hidden linear layer and an output
# projection back to the vocabulary size.
network Nimertes:
  layers:
    encoder: Embedding(vocabSize, hiddenSize)
    hiddenLinear: Linear(hiddenSize, hiddenSize)
    outputLayer: Linear(hiddenSize, vocabSize)
  forward x:
    # Embed the input ids, then apply `hiddenLinear` twice (the same layer is
    # reused for both hidden steps) with a tanh after every stage, and finally
    # project with `outputLayer`.
    let embedded = x.encoder.tanh
    let firstHidden = embedded.hiddenLinear.tanh
    let secondHidden = firstHidden.hiddenLinear.tanh
    secondHidden.outputLayer
|
|
|
|
###### Save/Load Model
|
|
proc saveModel(ctx: Context[AnyTensor[float32]], model: Nimertes, dir: string) =
  ## Writes every `Variable[Tensor[float32]]` field of every layer of `model`
  ## to `dir` as `<layer>_<field>.npy`.
  ##
  ## * `ctx`   - unused; kept so the call signature stays unchanged.
  ## * `model` - the network whose parameters are saved.
  ## * `dir`   - target directory; created first if it does not exist yet.
  echo "\nsaving model..."
  # Make sure the target directory exists so the first save on a fresh
  # checkout does not fail when opening the output files.
  createDir(dir)
  for layer, layerField in model.fieldPairs:
    # Bound to an ordinary local so the `fmt` literal below can refer to it
    # by name.
    let layerName = layer
    for field, tensorVariable in layerField.fieldPairs:
      # Only trainable float32 variables are persisted; any other field of a
      # layer is skipped at compile time.
      when tensorVariable is Variable[Tensor[float32]]:
        let fieldName = field
        tensorVariable.value.writeNPY(dir / fmt"{layerName}_{fieldName}.npy")
  echo "model saved"
|
|
|
|
proc initModel(ctx: Context[AnyTensor[float32]], model: Nimertes, dir: string): Nimertes =
  ## Loads previously saved parameters from `dir` into `model` and returns it.
  ##
  ## * `ctx`   - unused; kept so the call signature stays unchanged.
  ## * `model` - the network whose layer parameters receive the loaded values.
  ## * `dir`   - directory holding the `<layer>_<field>.npy` files.
  ##
  ## Each file is read with `readNPY`; a missing or unreadable file surfaces
  ## as whatever error `readNPY` raises.
  echo "\nweights exist"
  echo "\nloading model..."

  model.encoder.weight.value =
    readNPY[float32](dir / "encoder_weight.npy")

  model.hiddenLinear.weight.value =
    readNPY[float32](dir / "hiddenLinear_weight.npy")
  model.hiddenLinear.bias.value =
    readNPY[float32](dir / "hiddenLinear_bias.npy")

  # The files are named after the model field `outputLayer`. The previous
  # string dispatch tested for "outputLinear", which never matched, so the
  # output layer was left at its freshly initialised values.
  model.outputLayer.weight.value =
    readNPY[float32](dir / "outputLayer_weight.npy")
  model.outputLayer.bias.value =
    readNPY[float32](dir / "outputLayer_bias.npy")

  echo "model loaded\n"
  model
|
|
|
|
#### Initialize NN
|
|
# Freshly initialised network.
var model = ctx.init(Nimertes)

# Adam optimizer over the model's parameters.
# NOTE(review): beta2 is set to the same value as beta1 (0.9) — confirm this
# is intentional.
var optim = model.optimizer(
  Adam,
  learningRate = 3e-4'f32,
  beta1 = 0.9'f32,
  beta2 = 0.9'f32,
  eps = 1e-5'f32
)

# Resume from saved weights when the encoder weight file is present.
if "tinyBiGram/encoder_weight.npy".fileExists:
  model = ctx.initModel(model, "tinyBiGram")
|
|
|
|
###### Generate Text
|
|
proc generateText(ctx: Context[AnyTensor[float32]], model: Nimertes,
                  seedCharacters = "Wh", seqLen = blockSize,
                  temperature = 0.8'f32): string =
  ## Generates text by repeatedly sampling from `model`, without recording
  ## gradients.
  ##
  ## * `seedCharacters` - text the result starts with; its last encoded
  ##   element is the first model input.
  ## * `seqLen`         - number of sampling steps appended after the seed.
  ## * `temperature`    - the model output is divided by this value before
  ##   the softmax.
  ##
  ## Returns the seed followed by the decoded samples.
  ctx.no_grad_mode:
    result = seedCharacters

    let seed = encodeString(seedCharacters, stringToInt).toTensor.unsqueeze(1)

    # Only the last element of the encoded seed is fed to the model first.
    var current = seed[^1, _]

    for _ in 0 ..< seqLen:
      let prediction = model.forward(current.squeeze(0))

      # Divide by the temperature, then turn the result into probabilities.
      var scaled = prediction.value
      scaled /.= temperature
      let distribution = scaled.softmax().squeeze(0)

      # Draw the next token id and append its decoded form to the result.
      let nextToken = distribution.sample()
      result &= decodeString(nextToken, intToString)

      # The sampled token becomes the next 1x1 input.
      current = newTensor[int](1, 1)
      current[0, 0] = nextToken
|
|
|
|
###### Training
|
|
# Loss history: one entry per training iteration, plus the index of that
# iteration for plotting.
var
  totalLoss: seq[float]
  plotidx: seq[float]

for epoch in 0 .. numEpochs:
  # A batch is drawn on every iteration, including evaluation iterations.
  let (trainingBatch, trainingBatchNext) =
    getBatch("train", trainingSet, validationSet)

  if epoch %% evalIter != 0:
    # Training iteration: one forward/backward/update step per batch row.
    var batchLoss: Variable[Tensor[float32]]

    for row in 0 ..< batchSize:
      let
        inputTensor = trainingBatch[row, _]
        targetTensor = trainingBatchNext[row, _]
        prediction = model.forward(inputTensor.squeeze(0))

      batchLoss = prediction.sparseSoftmaxCrossEntropy(
        target = targetTensor.squeeze(0))

      batchLoss.backprop()
      optim.update()

    # Record the loss of the last row processed in this iteration.
    totalLoss.add(batchLoss.value[0])
    plotidx.add(epoch.float)
  else:
    # Evaluation iteration: print a sample and checkpoint the weights.
    echo "\n", ctx.generateText(model), "\n"
    ctx.saveModel(model, "tinyBiGram")

###### Plot results and show final output
echo ctx.generateText(model)
|
|
|