melite/generator.nim

22 lines
511 B
Nim
Raw Normal View History

2024-10-22 22:26:05 +00:00
import std/algorithm
import random
import arraymancer
proc searchsorted[T](x: openArray[T], value: T, leftSide: static bool = true): int =
  ## Binary search over the ascending-sorted `x`, returning the position at
  ## which `value` would be inserted to keep `x` sorted (in the spirit of
  ## NumPy's `searchsorted`).
  ##
  ## * `leftSide = true` (default): index of the first element `>= value`.
  ## * `leftSide = false`: index of the first element `> value`.
  ##
  ## Returns `x.len` when no element satisfies the condition. `leftSide` is a
  ## `static` parameter, so only the selected branch is compiled.
  when leftSide: lowerBound(x, value)
  else: upperBound(x, value)
proc sample*(probs: Tensor[float32], rng: var Rand): int =
  ## Draws one index from the categorical distribution described by `probs`,
  ## using the caller-supplied generator `rng` (seed it for reproducible
  ## results).
  ##
  ## `probs` is treated as non-negative weights. They need not sum to exactly
  ## 1: the uniform draw is scaled by their actual total, so float32 rounding
  ## (e.g. a softmax summing to 0.99999994) cannot push the draw past the end.
  ## NOTE(review): assumes `probs` is 1-D (the cumulative sum is taken along
  ## axis 0 and then read as a flat buffer) — confirm with callers.
  ##
  ## Returns an index in `0 ..< probs.size`.
  ## Raises `ValueError` if `probs` is empty.
  if probs.size == 0:
    raise newException(ValueError, "sample: `probs` must not be empty")
  let
    cdf = cumsum(probs, axis = 0)
    n = cdf.size
    # Raw view over the cumulative sums. NOTE(review): assumes the tensor
    # returned by `cumsum` is contiguous — confirm for the arraymancer version.
    cdfA = cast[ptr UncheckedArray[float32]](cdf.unsafeRawOffset)
    total = cdfA[n - 1]
    # Draw in [0, total] and narrow explicitly to float32 so the search below
    # compares in the same precision as `cdf`.
    u = float32(rng.rand(float(total)))
  # First index whose cumulative weight is strictly greater than `u`; the
  # right-side search means zero-weight entries are never selected.
  result = cdfA.toOpenArray(0, n - 1).searchsorted(u, leftSide = false)
  # `u` may equal `total` (inclusive draw / float32 rounding), in which case
  # the search runs off the end. Fall back to the first index whose cumulative
  # weight reaches the total instead of returning an invalid index.
  if result >= n:
    result = cdfA.toOpenArray(0, n - 1).searchsorted(total, leftSide = true)

proc sample*(probs: Tensor[float32]): int =
  ## Convenience overload of `sample` that uses a freshly initialized
  ## generator (`initRand()`) on every call; results are therefore not
  ## reproducible. Pass your own `Rand` to the two-argument overload for that.
  var rng = initRand()
  sample(probs, rng)