Make minor refactoring

Signed-off-by: Pavel Korotkov <pavel@nspcc.ru>
This commit is contained in:
Pavel Korotkov 2021-04-06 12:06:11 +03:00 committed by Pavel Korotkov
parent f7007f2085
commit c21324bf77

View file

@ -2,26 +2,13 @@ package connections
import "math/rand" import "math/rand"
// https://www.keithschwarz.com/darts-dice-coins/ // See Vose's Alias Method (https://www.keithschwarz.com/darts-dice-coins/).
type Sampler struct { type Sampler struct {
randomGenerator *rand.Rand randomGenerator *rand.Rand
probabilities []float64 probabilities []float64
alias []int alias []int
} }
type workList []int
func (wl *workList) push(e int) {
*wl = append(*wl, e)
}
func (wl *workList) pop() int {
l := len(*wl) - 1
n := (*wl)[l]
*wl = (*wl)[:l]
return n
}
func NewSampler(probabilities []float64, source rand.Source) *Sampler { func NewSampler(probabilities []float64, source rand.Source) *Sampler {
sampler := &Sampler{} sampler := &Sampler{}
var ( var (
@ -39,28 +26,28 @@ func NewSampler(probabilities []float64, source rand.Source) *Sampler {
} }
for i, pi := range p { for i, pi := range p {
if pi < 1 { if pi < 1 {
small = append(small, i) small.add(i)
} else { } else {
large = append(large, i) large.add(i)
} }
} }
for len(large) > 0 && len(small) > 0 { for len(small) > 0 && len(large) > 0 {
l, g := small.pop(), large.pop() l, g := small.remove(), large.remove()
sampler.probabilities[l] = p[l] sampler.probabilities[l] = p[l]
sampler.alias[l] = g sampler.alias[l] = g
p[g] = (p[g] + p[l]) - 1 p[g] = p[g] + p[l] - 1
if p[g] < 1 { if p[g] < 1 {
small.push(g) small.add(g)
} else { } else {
large.push(g) large.add(g)
} }
} }
for len(large) > 0 { for len(large) > 0 {
g := large.pop() g := large.remove()
sampler.probabilities[g] = 1 sampler.probabilities[g] = 1
} }
for len(small) > 0 { for len(small) > 0 {
l := small.pop() l := small.remove()
sampler.probabilities[l] = 1 sampler.probabilities[l] = 1
} }
return sampler return sampler
@ -74,3 +61,16 @@ func (g *Sampler) Next() int {
} }
return g.alias[i] return g.alias[i]
} }
type workList []int
func (wl *workList) add(e int) {
*wl = append(*wl, e)
}
func (wl *workList) remove() int {
l := len(*wl) - 1
n := (*wl)[l]
*wl = (*wl)[:l]
return n
}