- Adding tests for MX round-robin (#358)

- Implementing MX round-robin
- Slight tidy
This commit is contained in:
Jonathan Dickinson 2016-10-27 09:09:16 +02:00 committed by Miek Gieben
parent 219bfd0493
commit 7ce7100122
2 changed files with 121 additions and 44 deletions

View file

@ -28,6 +28,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
func roundRobin(in []dns.RR) []dns.RR { func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{} cname := []dns.RR{}
address := []dns.RR{} address := []dns.RR{}
mx := []dns.RR{}
rest := []dns.RR{} rest := []dns.RR{}
for _, r := range in { for _, r := range in {
switch r.Header().Rrtype { switch r.Header().Rrtype {
@ -35,17 +36,29 @@ func roundRobin(in []dns.RR) []dns.RR {
cname = append(cname, r) cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA:
address = append(address, r) address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default: default:
rest = append(rest, r) rest = append(rest, r)
} }
} }
switch l := len(address); l { roundRobinShuffle(address)
roundRobinShuffle(mx)
out := append(cname, rest...)
out = append(out, address...)
out = append(out, mx...)
return out
}
func roundRobinShuffle(records []dns.RR) {
switch l := len(records); l {
case 0, 1: case 0, 1:
break break
case 2: case 2:
if dns.Id()%2 == 0 { if dns.Id()%2 == 0 {
address[0], address[1] = address[1], address[0] records[0], records[1] = records[1], records[0]
} }
default: default:
for j := 0; j < l*(int(dns.Id())%4+1); j++ { for j := 0; j < l*(int(dns.Id())%4+1); j++ {
@ -54,12 +67,9 @@ func roundRobin(in []dns.RR) []dns.RR {
if q == p { if q == p {
p = (p + 1) % l p = (p + 1) % l
} }
address[q], address[p] = address[p], address[q] records[q], records[p] = records[p], records[q]
} }
} }
out := append(cname, rest...)
out = append(out, address...)
return out
} }
// Write implements the dns.ResponseWriter interface. // Write implements the dns.ResponseWriter interface.

View file

@ -16,44 +16,66 @@ func TestLoadBalance(t *testing.T) {
// the first X records must be cnames after this test // the first X records must be cnames after this test
tests := []struct { tests := []struct {
answer []dns.RR answer []dns.RR
extra []dns.RR extra []dns.RR
cnameAnswer int cnameAnswer int
cnameExtra int cnameExtra int
addressAnswer int
addressExtra int
mxAnswer int
mxExtra int
}{ }{
{ {
answer: []dns.RR{ answer: []dns.RR{
newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newMX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."),
newMX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."),
}, },
cnameAnswer: 4, cnameAnswer: 4,
addressAnswer: 1,
mxAnswer: 3,
}, },
{ {
answer: []dns.RR{ answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
}, },
cnameAnswer: 1, cnameAnswer: 1,
addressAnswer: 1,
mxAnswer: 1,
}, },
{ {
answer: []dns.RR{ answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
}, },
extra: []dns.RR{ extra: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
}, },
cnameAnswer: 1, cnameAnswer: 1,
cnameExtra: 1, cnameExtra: 1,
addressAnswer: 3,
addressExtra: 4,
mxAnswer: 3,
mxExtra: 3,
}, },
} }
@ -71,29 +93,73 @@ func TestLoadBalance(t *testing.T) {
continue continue
} }
cname := 0
for _, r := range rec.Msg.Answer { cname, address, mx, sorted := countRecords(rec.Msg.Answer)
if r.Header().Rrtype != dns.TypeCNAME { if !sorted {
break t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i)
}
cname++
} }
if cname != test.cnameAnswer { if cname != test.cnameAnswer {
t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname)
} }
cname = 0 if address != test.addressAnswer {
for _, r := range rec.Msg.Extra { t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address)
if r.Header().Rrtype != dns.TypeCNAME { }
break if mx != test.mxAnswer {
} t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx)
cname++ }
cname, address, mx, sorted = countRecords(rec.Msg.Extra)
if !sorted {
t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i)
} }
if cname != test.cnameExtra { if cname != test.cnameExtra {
t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname) t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname)
}
if address != test.addressExtra {
t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address)
}
if mx != test.mxExtra {
t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx)
} }
} }
} }
func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) {
const (
Start = iota
CNAMERecords
ARecords
MXRecords
Any
)
// The order of the records is used to determine if the round-robin actually did anything.
sorted = true
cname = 0
address = 0
mx = 0
state := Start
for _, r := range result {
switch r.Header().Rrtype {
case dns.TypeCNAME:
sorted = sorted && state <= CNAMERecords
state = CNAMERecords
cname++
case dns.TypeA, dns.TypeAAAA:
sorted = sorted && state <= ARecords
state = ARecords
address++
case dns.TypeMX:
sorted = sorted && state <= MXRecords
state = MXRecords
mx++
default:
state = Any
}
}
return
}
func handler() middleware.Handler { func handler() middleware.Handler {
return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
w.WriteMsg(r) w.WriteMsg(r)
@ -104,3 +170,4 @@ func handler() middleware.Handler {
func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
func newMX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }