Make minor refactoring
Signed-off-by: Pavel Korotkov <pavel@nspcc.ru>
This commit is contained in:
parent
f7007f2085
commit
c21324bf77
1 changed files with 23 additions and 23 deletions
|
@ -2,26 +2,13 @@ package connections
|
||||||
|
|
||||||
import "math/rand"
|
import "math/rand"
|
||||||
|
|
||||||
// https://www.keithschwarz.com/darts-dice-coins/
|
// See Vose's Alias Method (https://www.keithschwarz.com/darts-dice-coins/).
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
randomGenerator *rand.Rand
|
randomGenerator *rand.Rand
|
||||||
probabilities []float64
|
probabilities []float64
|
||||||
alias []int
|
alias []int
|
||||||
}
|
}
|
||||||
|
|
||||||
type workList []int
|
|
||||||
|
|
||||||
func (wl *workList) push(e int) {
|
|
||||||
*wl = append(*wl, e)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (wl *workList) pop() int {
|
|
||||||
l := len(*wl) - 1
|
|
||||||
n := (*wl)[l]
|
|
||||||
*wl = (*wl)[:l]
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSampler(probabilities []float64, source rand.Source) *Sampler {
|
func NewSampler(probabilities []float64, source rand.Source) *Sampler {
|
||||||
sampler := &Sampler{}
|
sampler := &Sampler{}
|
||||||
var (
|
var (
|
||||||
|
@ -39,28 +26,28 @@ func NewSampler(probabilities []float64, source rand.Source) *Sampler {
|
||||||
}
|
}
|
||||||
for i, pi := range p {
|
for i, pi := range p {
|
||||||
if pi < 1 {
|
if pi < 1 {
|
||||||
small = append(small, i)
|
small.add(i)
|
||||||
} else {
|
} else {
|
||||||
large = append(large, i)
|
large.add(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for len(large) > 0 && len(small) > 0 {
|
for len(small) > 0 && len(large) > 0 {
|
||||||
l, g := small.pop(), large.pop()
|
l, g := small.remove(), large.remove()
|
||||||
sampler.probabilities[l] = p[l]
|
sampler.probabilities[l] = p[l]
|
||||||
sampler.alias[l] = g
|
sampler.alias[l] = g
|
||||||
p[g] = (p[g] + p[l]) - 1
|
p[g] = p[g] + p[l] - 1
|
||||||
if p[g] < 1 {
|
if p[g] < 1 {
|
||||||
small.push(g)
|
small.add(g)
|
||||||
} else {
|
} else {
|
||||||
large.push(g)
|
large.add(g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for len(large) > 0 {
|
for len(large) > 0 {
|
||||||
g := large.pop()
|
g := large.remove()
|
||||||
sampler.probabilities[g] = 1
|
sampler.probabilities[g] = 1
|
||||||
}
|
}
|
||||||
for len(small) > 0 {
|
for len(small) > 0 {
|
||||||
l := small.pop()
|
l := small.remove()
|
||||||
sampler.probabilities[l] = 1
|
sampler.probabilities[l] = 1
|
||||||
}
|
}
|
||||||
return sampler
|
return sampler
|
||||||
|
@ -74,3 +61,16 @@ func (g *Sampler) Next() int {
|
||||||
}
|
}
|
||||||
return g.alias[i]
|
return g.alias[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type workList []int
|
||||||
|
|
||||||
|
func (wl *workList) add(e int) {
|
||||||
|
*wl = append(*wl, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wl *workList) remove() int {
|
||||||
|
l := len(*wl) - 1
|
||||||
|
n := (*wl)[l]
|
||||||
|
*wl = (*wl)[:l]
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue