21 lines
511 B
Nim
Executable file
21 lines
511 B
Nim
Executable file
import std/algorithm
|
|
import random
|
|
|
|
import arraymancer
|
|
|
|
|
|
proc searchsorted[T](x: openArray[T], value: T,
                     leftSide: static bool = true): int =
  ## Binary search for the position at which `value` would be inserted into
  ## the ascending-sorted `x` while keeping it sorted (cf. NumPy
  ## `searchsorted`).
  ##
  ## * `leftSide = true` (default): index of the first element `>= value`
  ##   (`algorithm.lowerBound`).
  ## * `leftSide = false`: index of the first element `> value`
  ##   (`algorithm.upperBound`).
  ##
  ## Returns `x.len` when no element qualifies. The side is a `static`
  ## parameter, so only the selected search is compiled in.
  when not leftSide:
    return upperBound(x, value)
  else:
    return lowerBound(x, value)
|
|
|
|
proc sample*(probs: Tensor[float32], rng: var Rand): int =
  ## Draws one index from the categorical distribution whose non-negative
  ## weights are `probs`, using the caller-supplied `rng` (so results can be
  ## made reproducible by seeding it).
  ##
  ## The weights are scaled by their total, so they need not sum to exactly
  ## 1. An index is returned only if its weight is positive.
  ##
  ## Raises `ValueError` if `probs` is empty or its weights do not sum to a
  ## positive number (including NaN).
  ##
  ## NOTE(review): assumes `probs` is rank-1 with non-negative entries —
  ## `cumsum` runs along axis 0 and the result is then read as one flat
  ## array; confirm with callers.
  if probs.size == 0:
    raise newException(ValueError, "sample: `probs` must not be empty")
  let
    cdf = cumsum(probs, axis = 0)
    cdfA = cast[ptr UncheckedArray[float32]](cdf.unsafeRawOffset)
    last = cdf.size - 1
    total = cdfA[last]
  # Written as `not (x > 0)` so that a NaN total is rejected as well.
  if not (total > 0'f32):
    raise newException(ValueError,
                       "sample: `probs` must have a positive sum")
  let
    # Map a uniform draw onto [0, total) instead of [0, 1): with float32
    # round-off (or unnormalised weights) the last CDF entry may be below 1,
    # and a draw beyond it would produce the out-of-range index `size`.
    u = float32(rng.rand(1.0) * float(total))
    # First index at which the CDF reaches its maximum, i.e. the last entry
    # with positive weight. Used to clamp the rare case where the float32
    # conversion above rounds `u` up to `total`.
    lastPositive = cdfA.toOpenArray(0, last).searchsorted(total)
  # Right-side search: first index whose cumulative weight exceeds `u`, so
  # zero-weight entries are skipped.
  result = min(cdfA.toOpenArray(0, last).searchsorted(u, leftSide = false),
               lastPositive)

proc sample*(probs: Tensor[float32]): int =
  ## Same as `sample(probs, rng)` but with a fresh `Rand` from `initRand()`
  ## on every call (the historical behaviour). Prefer the `rng` overload in
  ## loops or when results must be reproducible.
  var rng = initRand()
  sample(probs, rng)
|
|
|