22 lines
511 B
Nim
22 lines
511 B
Nim
|
import std/algorithm
|
||
|
import random
|
||
|
|
||
|
import arraymancer
|
||
|
|
||
|
|
||
|
proc searchsorted[T](x: openArray[T], value: T, leftSide: static bool = true): int =
  ## Returns the insertion index of `value` in `x`.
  ##
  ## * `leftSide = true` (default): delegates to `lowerBound`.
  ## * `leftSide = false`: delegates to `upperBound`.
  ##
  ## The side is a `static bool`, so the branch is selected at compile time.
  ## `x` is expected to already be sorted ascending, as the underlying
  ## binary searches from `std/algorithm` require.
  when not leftSide:
    upperBound(x, value)
  else:
    lowerBound(x, value)
|
||
|
|
||
|
proc sample*(probs: Tensor[float32]): int =
  ## Draws one random index weighted by `probs` using inverse-CDF sampling.
  ##
  ## A value `u` is drawn with `rand(1.0'f32)` from a freshly initialised
  ## `Rand`, the cumulative sum of `probs` along axis 0 is computed, and `u`
  ## is located in that cumulative sum with a right-sided binary search.
  ##
  ## The result is clamped to the last valid index. `rand` may return its
  ## upper bound, and float32 rounding can leave the final cumulative sum
  ## slightly below `1.0`; in both cases the unclamped search would yield
  ## `probs.size`, which is one past the end. For an empty tensor the result
  ## is `0`, as before.
  ##
  ## NOTE(review): assumes `probs` is a rank-1 tensor of non-negative weights
  ## that sum to ~1 -- confirm against callers.
  var rng = initRand()
  # Raw pointer view over the cumulative sums so that the stdlib binary
  # search (which works on `openArray`) can be applied to the tensor data.
  let
    u = rng.rand(1.0'f32)
    cdf = cumsum(probs, axis = 0)
    cdfA = cast[ptr UncheckedArray[float32]](cdf.unsafeRawOffset)
    lastIdx = max(cdf.size - 1, 0)
    found = cdfA.toOpenArray(0, cdf.size - 1).searchsorted(u, leftSide = false)
  # Guard against `u` landing at or beyond the final cumulative value.
  result = min(found, lastIdx)
|
||
|
|