From 7ce71001226ecdbac32db3f405afc7ef35f9dbfe Mon Sep 17 00:00:00 2001 From: Jonathan Dickinson Date: Thu, 27 Oct 2016 09:09:16 +0200 Subject: [PATCH] - Adding tests for MX round-robin (#358) - Implementing MX round-robin - Slight tidy --- middleware/loadbalance/loadbalance.go | 22 +++- middleware/loadbalance/loadbalance_test.go | 143 +++++++++++++++------ 2 files changed, 121 insertions(+), 44 deletions(-) diff --git a/middleware/loadbalance/loadbalance.go b/middleware/loadbalance/loadbalance.go index 59aad8a4f..7df0b31c6 100644 --- a/middleware/loadbalance/loadbalance.go +++ b/middleware/loadbalance/loadbalance.go @@ -28,6 +28,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { func roundRobin(in []dns.RR) []dns.RR { cname := []dns.RR{} address := []dns.RR{} + mx := []dns.RR{} rest := []dns.RR{} for _, r := range in { switch r.Header().Rrtype { @@ -35,17 +36,29 @@ func roundRobin(in []dns.RR) []dns.RR { cname = append(cname, r) case dns.TypeA, dns.TypeAAAA: address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) default: rest = append(rest, r) } } - switch l := len(address); l { + roundRobinShuffle(address) + roundRobinShuffle(mx) + + out := append(cname, rest...) + out = append(out, address...) + out = append(out, mx...) + return out +} + +func roundRobinShuffle(records []dns.RR) { + switch l := len(records); l { case 0, 1: break case 2: if dns.Id()%2 == 0 { - address[0], address[1] = address[1], address[0] + records[0], records[1] = records[1], records[0] } default: for j := 0; j < l*(int(dns.Id())%4+1); j++ { @@ -54,12 +67,9 @@ func roundRobin(in []dns.RR) []dns.RR { if q == p { p = (p + 1) % l } - address[q], address[p] = address[p], address[q] + records[q], records[p] = records[p], records[q] } } - out := append(cname, rest...) - out = append(out, address...) - return out } // Write implements the dns.ResponseWriter interface. diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go index 5e240be13..2a5096004 100644 --- a/middleware/loadbalance/loadbalance_test.go +++ b/middleware/loadbalance/loadbalance_test.go @@ -16,44 +16,66 @@ func TestLoadBalance(t *testing.T) { // the first X records must be cnames after this test tests := []struct { - answer []dns.RR - extra []dns.RR - cnameAnswer int - cnameExtra int + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int }{ { answer: []dns.RR{ - newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), - newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), - newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), - newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + newMX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."), + newMX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), }, - cnameAnswer: 4, + cnameAnswer: 4, + addressAnswer: 1, + mxAnswer: 3, }, { answer: []dns.RR{ - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), - newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), }, - cnameAnswer: 1, + cnameAnswer: 1, + addressAnswer: 1, + mxAnswer: 1, }, { answer: []dns.RR{ - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), - newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), }, extra: []dns.RR{ - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), - newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), - newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), - newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), - newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), + newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), }, - cnameAnswer: 1, - cnameExtra: 1, + cnameAnswer: 1, + cnameExtra: 1, + addressAnswer: 3, + addressExtra: 4, + mxAnswer: 3, + mxExtra: 3, }, } @@ -71,29 +93,73 @@ func TestLoadBalance(t *testing.T) { continue } - cname := 0 - for _, r := range rec.Msg.Answer { - if r.Header().Rrtype != dns.TypeCNAME { - break - } - cname++ + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i) } if cname != test.cnameAnswer { - t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) + t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname) } - cname = 0 - for _, r := range rec.Msg.Extra { - if r.Header().Rrtype != dns.TypeCNAME { - break - } - cname++ + if address != test.addressAnswer { + t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i) } if cname != test.cnameExtra { - t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname) + t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx) } } } +func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) { + const ( + Start = iota + CNAMERecords + ARecords + MXRecords + Any + ) + + // The order of the records is used to determine if the round-robin actually did anything. + sorted = true + cname = 0 + address = 0 + mx = 0 + state := Start + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeCNAME: + sorted = sorted && state <= CNAMERecords + state = CNAMERecords + cname++ + case dns.TypeA, dns.TypeAAAA: + sorted = sorted && state <= ARecords + state = ARecords + address++ + case dns.TypeMX: + sorted = sorted && state <= MXRecords + state = MXRecords + mx++ + default: + state = Any + } + } + return +} + func handler() middleware.Handler { return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { w.WriteMsg(r) @@ -104,3 +170,4 @@ func handler() middleware.Handler { func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } +func newMX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }