frostfs-sdk-go/pool/sampler.go

93 lines
1.9 KiB
Go
Raw Permalink Normal View History

package pool
import (
"math/rand"
"sync"
)
// sampler is a weighted random index generator built on Vose's Alias
// Method (https://www.keithschwarz.com/darts-dice-coins/).
//
// The two tables are filled once by newSampler; Next only reads them.
type sampler struct {
	// probabilities[i] is the threshold a uniform draw from [0, 1) is
	// compared against in Next: below it, Next returns i itself.
	probabilities []float64

	// alias[i] is the index Next returns instead of i when the draw is not
	// below probabilities[i].
	alias []int

	// mu is held by Next around every call into randomGenerator.
	mu sync.Mutex

	// randomGenerator supplies both the column pick and the uniform draw.
	randomGenerator *rand.Rand
}
// newSampler builds a sampler for the given probabilities, drawing its
// randomness from source. The alias table is prepared up front with Vose's
// Alias Method, so the returned sampler yields indices in the half-open
// range [0, len(probabilities)).
//
// The input is used as given: it is neither validated nor normalized.
func newSampler(probabilities []float64, source rand.Source) *sampler {
	count := len(probabilities)
	s := &sampler{
		randomGenerator: rand.New(source),
		probabilities:   make([]float64, count),
		alias:           make([]int, count),
	}

	// Scale each probability by the number of outcomes and sort the indices
	// into those whose scaled value is below 1 (under) and the rest (over).
	var under, over workList
	scaled := make([]float64, count)
	for idx, prob := range probabilities {
		scaled[idx] = prob * float64(count)
		if scaled[idx] < 1 {
			under.add(idx)
		} else {
			over.add(idx)
		}
	}

	// Pair one "under" index with one "over" index at a time: the under
	// index keeps its own scaled value as threshold and borrows the rest of
	// its column from the over index, which is then re-classified by what it
	// has left.
	for len(under) > 0 && len(over) > 0 {
		lo := under.remove()
		hi := over.remove()

		s.probabilities[lo] = scaled[lo]
		s.alias[lo] = hi

		scaled[hi] = scaled[hi] + scaled[lo] - 1
		if scaled[hi] < 1 {
			under.add(hi)
		} else {
			over.add(hi)
		}
	}

	// Whatever is left in either list gets threshold 1, i.e. such an index
	// is always returned as itself and never redirected through alias.
	for _, idx := range over {
		s.probabilities[idx] = 1
	}
	for _, idx := range under {
		s.probabilities[idx] = 1
	}

	return s
}
// Next returns the next (not so) random index from the sampler: a column is
// picked uniformly, then a uniform draw from [0, 1) decides between the
// column itself and its alias.
//
// This method is safe for concurrent use by multiple goroutines: the shared
// generator is only used while mu is held, and the tables are only read.
func (s *sampler) Next() int {
	s.mu.Lock()
	column := s.randomGenerator.Intn(len(s.alias))
	toss := s.randomGenerator.Float64()
	s.mu.Unlock()

	picked := s.alias[column]
	if toss < s.probabilities[column] {
		picked = column
	}
	return picked
}
// workList is a last-in, first-out list of ints backed by a slice.
type workList []int

// add appends e to the end of the list.
func (wl *workList) add(e int) {
	grown := append(*wl, e)
	*wl = grown
}

// remove takes the most recently added element off the list and returns it.
// The list must not be empty: removing from an empty list panics.
func (wl *workList) remove() int {
	items := *wl
	top := len(items) - 1
	e := items[top]
	*wl = items[:top]
	return e
}