diff --git a/connections/sampler.go b/connections/sampler.go index 8dced58..e6439af 100644 --- a/connections/sampler.go +++ b/connections/sampler.go @@ -2,26 +2,13 @@ package connections import "math/rand" -// https://www.keithschwarz.com/darts-dice-coins/ +// See Vose's Alias Method (https://www.keithschwarz.com/darts-dice-coins/). type Sampler struct { randomGenerator *rand.Rand probabilities []float64 alias []int } -type workList []int - -func (wl *workList) push(e int) { - *wl = append(*wl, e) -} - -func (wl *workList) pop() int { - l := len(*wl) - 1 - n := (*wl)[l] - *wl = (*wl)[:l] - return n -} - func NewSampler(probabilities []float64, source rand.Source) *Sampler { sampler := &Sampler{} var ( @@ -39,28 +26,28 @@ func NewSampler(probabilities []float64, source rand.Source) *Sampler { } for i, pi := range p { if pi < 1 { - small = append(small, i) + small.add(i) } else { - large = append(large, i) + large.add(i) } } - for len(large) > 0 && len(small) > 0 { - l, g := small.pop(), large.pop() + for len(small) > 0 && len(large) > 0 { + l, g := small.remove(), large.remove() sampler.probabilities[l] = p[l] sampler.alias[l] = g - p[g] = (p[g] + p[l]) - 1 + p[g] = p[g] + p[l] - 1 if p[g] < 1 { - small.push(g) + small.add(g) } else { - large.push(g) + large.add(g) } } for len(large) > 0 { - g := large.pop() + g := large.remove() sampler.probabilities[g] = 1 } for len(small) > 0 { - l := small.pop() + l := small.remove() sampler.probabilities[l] = 1 } return sampler @@ -74,3 +61,16 @@ func (g *Sampler) Next() int { } return g.alias[i] } + +type workList []int + +func (wl *workList) add(e int) { + *wl = append(*wl, e) +} + +func (wl *workList) remove() int { + l := len(*wl) - 1 + n := (*wl)[l] + *wl = (*wl)[:l] + return n +}