diff --git a/middleware/etcd/etcd.go b/middleware/etcd/etcd.go index 1b98d69bc..84b884bcb 100644 --- a/middleware/etcd/etcd.go +++ b/middleware/etcd/etcd.go @@ -103,7 +103,7 @@ Nodes: if err := json.Unmarshal([]byte(n.Value), serv); err != nil { return nil, err } - b := msg.Service{Host: serv.Host, Port: serv.Port, Priority: serv.Priority, Weight: serv.Weight, Text: serv.Text} + b := msg.Service{Host: serv.Host, Port: serv.Port, Priority: serv.Priority, Weight: serv.Weight, Text: serv.Text, Key: n.Key} if _, ok := bx[b]; ok { continue } diff --git a/middleware/etcd/lookup.go b/middleware/etcd/lookup.go index 43dd6b56b..c996b986c 100644 --- a/middleware/etcd/lookup.go +++ b/middleware/etcd/lookup.go @@ -10,13 +10,20 @@ import ( "github.com/miekg/dns" ) -func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { - services, err := e.Records(state.Name(), false) +func (e Etcd) records(state middleware.State, exact bool) ([]msg.Service, error) { + services, err := e.Records(state.Name(), exact) if err != nil { return nil, err } - services = msg.Group(services) + return services, nil +} + +func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { + services, err := e.records(state, false) + if err != nil { + return nil, err + } for _, serv := range services { ip := net.ParseIP(serv.Host) @@ -73,13 +80,11 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR) ( } func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { - services, err := e.Records(state.Name(), false) + services, err := e.records(state, false) if err != nil { return nil, err } - services = msg.Group(services) - for _, serv := range services { ip := net.ParseIP(serv.Host) switch { @@ -139,13 +144,11 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR // SRV returns SRV records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. func (e Etcd) SRV(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { - services, err := e.Records(state.Name(), false) + services, err := e.records(state, false) if err != nil { return nil, nil, err } - services = msg.Group(services) - // Looping twice to get the right weight vs priority w := make(map[int]int) for _, serv := range services { @@ -224,7 +227,7 @@ func (e Etcd) SRV(zone string, state middleware.State) (records []dns.RR, extra // MX returns MX records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. func (e Etcd) MX(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { - services, err := e.Records(state.Name(), false) + services, err := e.records(state, false) if err != nil { return nil, nil, err } @@ -282,13 +285,11 @@ func (e Etcd) MX(zone string, state middleware.State) (records []dns.RR, extra [ } func (e Etcd) CNAME(zone string, state middleware.State) (records []dns.RR, err error) { - services, err := e.Records(state.Name(), true) + services, err := e.records(state, true) if err != nil { return nil, err } - services = msg.Group(services) - if len(services) > 0 { serv := services[0] if ip := net.ParseIP(serv.Host); ip == nil { @@ -299,13 +300,11 @@ func (e Etcd) CNAME(zone string, state middleware.State) (records []dns.RR, err } func (e Etcd) TXT(zone string, state middleware.State) (records []dns.RR, err error) { - services, err := e.Records(state.Name(), false) + services, err := e.records(state, false) if err != nil { return nil, err } - services = msg.Group(services) - for _, serv := range services { if serv.Text == "" { continue diff --git a/middleware/etcd/lookup_test.go b/middleware/etcd/lookup_test.go index 5efe6a314..56e674047 100644 --- a/middleware/etcd/lookup_test.go +++ b/middleware/etcd/lookup_test.go @@ -97,18 +97,20 @@ var dnsTestCases = []dnsTestCase{ { Qname: "*.region1.skydns.test.", Qtype: dns.TypeSRV, Answer: []dns.RR{ - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 0 sub.server1."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 0 unresolvable.skydns.test."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 80 sub.server2."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 8080 a.server1.prod.region1.skydns.test."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 8080 b.server1.prod.region1.skydns.test."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 8080 b.server6.prod.region1.skydns.test."), - newSRV("*.region1.skydns.test. 300 IN SRV 10 14 8080 dev.server1."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 0 sub.server1."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 0 unresolvable.skydns.test."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 80 sub.server2."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 a.server1.prod.region1.skydns.test."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 b.server1.prod.region1.skydns.test."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 b.server6.prod.region1.skydns.test."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 c.sub.region1.skydns.test."), + newSRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 dev.server1."), }, Extra: []dns.RR{ newA("a.server1.prod.region1.skydns.test. 300 IN A 10.0.0.1"), newA("b.server1.prod.region1.skydns.test. 300 IN A 10.0.0.2"), newAAAA("b.server6.prod.region1.skydns.test. 300 IN AAAA ::1"), + newA("c.sub.region1.skydns.test. 300 IN A 10.0.0.1"), }, }, // Wildcard Test diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go index 838c1b93a..aa148fd7b 100644 --- a/middleware/etcd/setup_test.go +++ b/middleware/etcd/setup_test.go @@ -118,13 +118,16 @@ func TestLookup(t *testing.T) { if !checkSection(t, tc, Answer, resp.Answer) { t.Logf("%v\n", resp) + t.Fatal() } if !checkSection(t, tc, Ns, resp.Ns) { t.Logf("%v\n", resp) + t.Fatal() } if !checkSection(t, tc, Extra, resp.Extra) { t.Logf("%v\n", resp) + t.Fatal() } } } @@ -161,51 +164,51 @@ func checkSection(t *testing.T, tc dnsTestCase, sect Section, rr []dns.RR) bool for i, a := range rr { if a.Header().Name != section[i].Header().Name { - t.Errorf("answer %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name) + t.Errorf("rr %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name) return false } // 303 signals: don't care what the ttl is. if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl { - t.Errorf("Answer %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl) + t.Errorf("rr %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl) return false } if a.Header().Rrtype != section[i].Header().Rrtype { - t.Errorf("answer %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype) + t.Errorf("rr %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype) return false } switch x := a.(type) { case *dns.SRV: if x.Priority != section[i].(*dns.SRV).Priority { - t.Errorf("answer %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority) + t.Errorf("rr %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority) return false } if x.Weight != section[i].(*dns.SRV).Weight { - t.Errorf("answer %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight) + t.Errorf("rr %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight) return false } if x.Port != section[i].(*dns.SRV).Port { - t.Errorf("answer %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port) + t.Errorf("rr %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port) return false } if x.Target != section[i].(*dns.SRV).Target { - t.Errorf("answer %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target) + t.Errorf("rr %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target) return false } case *dns.A: if x.A.String() != section[i].(*dns.A).A.String() { - t.Errorf("answer %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String()) + t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String()) return false } case *dns.AAAA: if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() { - t.Errorf("answer %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) + t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) return false } case *dns.TXT: for j, txt := range x.Txt { if txt != section[i].(*dns.TXT).Txt[j] { - t.Errorf("answer %d should have a Txt of %q, but has %q", i, section[i].(*dns.TXT).Txt[j], txt) + t.Errorf("rr %d should have a Txt of %q, but has %q", i, section[i].(*dns.TXT).Txt[j], txt) return false } }