coredns/plugin/loadbalance/weighted.go
Zhizhen He 5de473da1c
fix: remove unnecessary conversion (#6258)
Signed-off-by: Zhizhen He <hezhizhen.yi@gmail.com>
2023-08-14 15:14:09 +02:00

329 lines
7.1 KiB
Go

package loadbalance
import (
"bufio"
"bytes"
"crypto/md5"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// weightedRR holds the state of the "weighted-round-robin" policy.
type weightedRR struct {
	// fileName is the path of the weight file.
	fileName string
	// reload is the period between two reloads of the weight file;
	// zero disables the periodic reload.
	reload time.Duration
	// md5sum is the checksum of the weight file contents seen by the
	// latest update; it is used to skip parsing of unchanged contents.
	md5sum [md5.Size]byte
	// domains maps a fully qualified domain name to the weights of its
	// addresses.
	domains map[string]weights
	// randomGen supplies the random numbers used to select the top address.
	randomGen
	// mutex protects access to the weights.
	mutex sync.Mutex
}

// weights is the list of weighted addresses of a single domain.
type weights []*weightItem

// weightItem is the weight assigned to one address.
type weightItem struct {
	address net.IP
	value   uint8
}

// randomGen is the random unsigned integer generator used by the policy.
type randomGen interface {
	randInit()
	randUint(limit uint) uint
}
// randomUint is a random unsigned integer generator backed by math/rand.
type randomUint struct {
	rn *rand.Rand
}

// randInit (re)creates the underlying generator, seeded with the current
// time in nanoseconds.
func (r *randomUint) randInit() {
	r.rn = rand.New(rand.NewSource(time.Now().UnixNano()))
}

// randUint returns a pseudo-random number in the half-open range [0, limit).
// A limit of zero yields 0 instead of panicking, and the generator is
// initialised on first use if randInit has not been called yet.
func (r *randomUint) randUint(limit uint) uint {
	if limit == 0 {
		// rand.Intn panics for a non-positive argument
		return 0
	}
	if r.rn == nil {
		r.randInit()
	}
	return uint(r.rn.Intn(int(limit)))
}
// weightedShuffle applies the weighted round robin policy to the answer and
// extra sections of res when the first question asks for an A, AAAA or SRV
// record. Any other response is returned unchanged.
func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg {
	qtype := res.Question[0].Qtype
	if qtype == dns.TypeA || qtype == dns.TypeAAAA || qtype == dns.TypeSRV {
		res.Answer = w.weightedRoundRobin(res.Answer)
		res.Extra = w.weightedRoundRobin(res.Extra)
	}
	return res
}
// weightedOnStartUp loads the weights for the first time and then starts the
// periodic reload. A weight file that cannot be opened is only reported as a
// warning when periodic reload is enabled, because the next reload tries
// again; every other error is returned as a "loadbalance" plugin error.
func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error {
	if err := w.updateWeights(); err != nil {
		if !errors.Is(err, errOpen) || w.reload == 0 {
			return plugin.Error("loadbalance", err)
		}
		log.Warningf("Failed to open weight file:%v. Will try again in %v",
			err, w.reload)
	}

	// launch the goroutine that reloads the weight file periodically
	w.periodicWeightUpdate(stopReloadChan)
	return nil
}
// createWeightedFuncs builds the shuffle, start-up and shutdown functions of
// the weighted round robin policy for the given weight file and reload period.
func createWeightedFuncs(weightFileName string,
	reload time.Duration) *lbFuncs {
	// closing this channel stops the periodic weight reload goroutine
	stopReloadChan := make(chan bool)

	weighted := &weightedRR{
		fileName:  weightFileName,
		reload:    reload,
		randomGen: &randomUint{},
	}
	weighted.randomGen.randInit()

	lb := &lbFuncs{weighted: weighted}
	lb.onStartUpFunc = func() error {
		return weightedOnStartUp(lb.weighted, stopReloadChan)
	}
	lb.onShutdownFunc = func() error {
		// stop the periodic weight reload goroutine
		close(stopReloadChan)
		return nil
	}
	lb.shuffleFunc = func(res *dns.Msg) *dns.Msg {
		return weightedShuffle(res, lb.weighted)
	}
	return lb
}
// weightedRoundRobin applies the weighted round robin policy to a record
// list. When the list holds at least one A or AAAA record, the result is
// ordered as CNAME records, other records, address records (with the selected
// top address first) and finally MX records. A list without address records
// is returned as is.
func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR {
	var cnames, addrs, mxs, others []dns.RR
	for _, rr := range in {
		rrtype := rr.Header().Rrtype
		switch {
		case rrtype == dns.TypeA || rrtype == dns.TypeAAAA:
			addrs = append(addrs, rr)
		case rrtype == dns.TypeCNAME:
			cnames = append(cnames, rr)
		case rrtype == dns.TypeMX:
			mxs = append(mxs, rr)
		default:
			others = append(others, rr)
		}
	}

	if len(addrs) == 0 {
		// nothing to balance
		return in
	}
	w.setTopRecord(addrs)

	out := make([]dns.RR, 0, len(in))
	out = append(out, cnames...)
	out = append(out, others...)
	out = append(out, addrs...)
	return append(out, mxs...)
}
// setTopRecord moves the selected address to the first position of the list
// by swapping it with the current first entry. The list is left untouched
// when the selection fails (negative index) or already points to the first
// entry.
func (w *weightedRR) setTopRecord(address []dns.RR) {
	if itop := w.topAddressIndex(address); itop > 0 {
		address[0], address[itop] = address[itop], address[0]
	}
}
// topAddressIndex selects the address that goes to the top of the list and
// returns its index in address, or -1 if no address could be selected.
// Every address gets the weight configured for it under the record's owner
// name, or the default weight 1 when none is configured. The whole selection
// runs with the mutex held.
func (w *weightedRR) topAddressIndex(address []dns.RR) int {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	type candidate struct {
		pos    int
		weight uint8
	}

	// Look up the weight of each address in the answer
	candidates := make([]candidate, len(address))
	var total uint
	for i, rr := range address {
		var ip net.IP
		switch rr.Header().Rrtype {
		case dns.TypeA:
			ip = rr.(*dns.A).A
		case dns.TypeAAAA:
			ip = rr.(*dns.AAAA).AAAA
		}

		weight := uint8(1) // default weight
		for _, item := range w.domains[rr.Header().Name] {
			if item.address.Equal(ip) {
				weight = item.value
				break
			}
		}
		candidates[i] = candidate{pos: i, weight: weight}
		total += uint(weight)
	}

	// Order the candidates by decreasing weight, then pick the one whose
	// cumulative weight range contains the random value
	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].weight > candidates[b].weight
	})
	pick := w.randUint(total)
	var running uint
	for _, c := range candidates {
		running += uint(c.weight)
		if pick < running {
			return c.pos
		}
	}

	// not expected to be reached
	log.Errorf("Internal error: cannot find top address (randv:%v wsum:%v)", pick, total)
	return -1
}
// periodicWeightUpdate starts a goroutine that reloads the weights from the
// weight file every w.reload. Nothing is started when the reload period is
// zero. The goroutine ends when stopReload is closed or receives a value;
// reload errors are logged and do not stop it.
func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) {
	if w.reload == 0 {
		return
	}

	go func() {
		ticker := time.NewTicker(w.reload)
		// release the ticker once the goroutine exits
		defer ticker.Stop()
		for {
			select {
			case <-stopReload:
				return
			case <-ticker.C:
				if err := w.updateWeights(); err != nil {
					log.Error(err)
				}
			}
		}
	}()
}
// updateWeights reads the weight file and, if its contents differ from the
// ones seen by the previous call, parses them and replaces the current
// weights. It returns errOpen when the file cannot be opened, and the read or
// parse error otherwise. The checksum is recorded before parsing, so contents
// that fail to parse are not parsed again until the file changes.
func (w *weightedRR) updateWeights() error {
	file, err := os.Open(filepath.Clean(w.fileName))
	if err != nil {
		return errOpen
	}
	defer file.Close()

	content, err := io.ReadAll(file)
	if err != nil {
		return err
	}

	// skip parsing when the contents are the same as last time
	sum := md5.Sum(content)
	if sum == w.md5sum {
		return nil
	}
	w.md5sum = sum

	// parse the weight file contents
	domains, err := w.parseWeights(bufio.NewScanner(bytes.NewReader(content)))
	if err != nil {
		return err
	}

	// the weights are read concurrently, swap them under the mutex
	w.mutex.Lock()
	w.domains = domains
	w.mutex.Unlock()

	log.Infof("Successfully reloaded weight file %s", w.fileName)
	return nil
}
// parseWeights parses the weight file contents delivered by scanner and
// returns the weights per domain name.
//
// Empty lines and lines starting with '#' are ignored. A line with a single
// field is a domain name (a trailing dot is appended when missing) and opens
// the section of that domain. A line with two fields is an IP address
// followed by its weight, a decimal number in the range 1-255, and belongs to
// the most recent domain name. Any other line is an error.
func (w *weightedRR) parseWeights(scanner *bufio.Scanner) (map[string]weights, error) {
	domains := make(map[string]weights)
	var current string // domain name of the section being parsed

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			// skip empty and comment lines
			continue
		}

		fields := strings.Fields(line)
		if len(fields) == 1 {
			// a lone IP address is most likely an address line without weight
			if net.ParseIP(fields[0]) != nil {
				return nil, fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)",
					fields[0], w.fileName)
			}
			current = fields[0]
			// make the name fully qualified
			if !strings.HasSuffix(current, ".") {
				current += "."
			}
			// a domain without addresses still gets an (empty) entry
			if _, seen := domains[current]; !seen {
				domains[current] = make(weights, 0)
			}
			continue
		}
		if len(fields) != 2 {
			return nil, fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", line, w.fileName)
		}

		// IP address followed by its weight value
		ip := net.ParseIP(fields[0])
		if ip == nil {
			return nil, fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName)
		}
		value, err := strconv.ParseUint(fields[1], 10, 8)
		if err != nil || value == 0 {
			return nil, fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName)
		}
		if current == "" {
			return nil, fmt.Errorf("Missing domain name in weight file %s", w.fileName)
		}
		domains[current] = append(domains[current], &weightItem{address: ip, value: uint8(value)})
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err)
	}
	return domains, nil
}