diff --git a/middleware/backend_lookup.go b/middleware/backend_lookup.go index 8b86495bb..41968c2fa 100644 --- a/middleware/backend_lookup.go +++ b/middleware/backend_lookup.go @@ -21,9 +21,11 @@ func A(b ServiceBackend, zone string, state request.Request, previousRecords []d } for _, serv := range services { - ip := net.ParseIP(serv.Host) - switch { - case ip == nil: + + what, ip := serv.HostType() + + switch what { + case dns.TypeANY: if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) { // x CNAME x is a direct loop, don't add those continue @@ -67,9 +69,11 @@ func A(b ServiceBackend, zone string, state request.Request, previousRecords []d records = append(records, newRecord) records = append(records, m1.Answer...) continue - case ip.To4() != nil: - records = append(records, serv.NewA(state.QName(), ip.To4())) - case ip.To4() == nil: + + case dns.TypeA: + records = append(records, serv.NewA(state.QName(), ip)) + + case dns.TypeAAAA: // nodata? } } @@ -84,9 +88,11 @@ func AAAA(b ServiceBackend, zone string, state request.Request, previousRecords } for _, serv := range services { - ip := net.ParseIP(serv.Host) - switch { - case ip == nil: + + what, ip := serv.HostType() + + switch what { + case dns.TypeANY: // Try to resolve as CNAME if it's not an IP, but only if we don't create loops. if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) { // x CNAME x is a direct loop, don't add those @@ -131,10 +137,12 @@ func AAAA(b ServiceBackend, zone string, state request.Request, previousRecords records = append(records, m1.Answer...) continue // both here again - case ip.To4() != nil: + + case dns.TypeA: // nada? - case ip.To4() == nil: - records = append(records, serv.NewAAAA(state.QName(), ip.To16())) + + case dns.TypeAAAA: + records = append(records, serv.NewAAAA(state.QName(), ip)) } } return records, debug, nil @@ -170,9 +178,11 @@ func SRV(b ServiceBackend, zone string, state request.Request, opt Options) (rec w1 *= float64(serv.Weight) } weight := uint16(math.Floor(w1)) - ip := net.ParseIP(serv.Host) - switch { - case ip == nil: + + what, ip := serv.HostType() + + switch what { + case dns.TypeANY: srv := serv.NewSRV(state.QName(), weight) records = append(records, srv) @@ -214,18 +224,13 @@ func SRV(b ServiceBackend, zone string, state request.Request, opt Options) (rec debug = append(debug, debugAddr...) } // IPv6 lookups here as well? AAAA(zone, state1, nil). - case ip.To4() != nil: + + case dns.TypeA, dns.TypeAAAA: serv.Host = msg.Domain(serv.Key) srv := serv.NewSRV(state.QName(), weight) records = append(records, srv) - extra = append(extra, serv.NewA(srv.Target, ip.To4())) - case ip.To4() == nil: - serv.Host = msg.Domain(serv.Key) - srv := serv.NewSRV(state.QName(), weight) - - records = append(records, srv) - extra = append(extra, serv.NewAAAA(srv.Target, ip.To16())) + extra = append(extra, newAddress(serv, srv.Target, ip, what)) } } return records, extra, debug, nil @@ -243,9 +248,9 @@ func MX(b ServiceBackend, zone string, state request.Request, opt Options) (reco if !serv.Mail { continue } - ip := net.ParseIP(serv.Host) - switch { - case ip == nil: + what, ip := serv.HostType() + switch what { + case dns.TypeANY: mx := serv.NewMX(state.QName()) records = append(records, mx) if _, ok := lookup[mx.Mx]; ok { @@ -284,14 +289,11 @@ func MX(b ServiceBackend, zone string, state request.Request, opt Options) (reco debug = append(debug, debugAddr...) } // e.AAAA as well - case ip.To4() != nil: + + case dns.TypeA, dns.TypeAAAA: serv.Host = msg.Domain(serv.Key) records = append(records, serv.NewMX(state.QName())) - extra = append(extra, serv.NewA(serv.Host, ip.To4())) - case ip.To4() == nil: - serv.Host = msg.Domain(serv.Key) - records = append(records, serv.NewMX(state.QName())) - extra = append(extra, serv.NewAAAA(serv.Host, ip.To16())) + extra = append(extra, newAddress(serv, serv.Host, ip, what)) } } return records, extra, debug, nil @@ -360,18 +362,15 @@ func NS(b ServiceBackend, zone string, state request.Request, opt Options) (reco state.Req.Question[0].Name = old for _, serv := range services { - ip := net.ParseIP(serv.Host) - switch { - case ip == nil: + what, ip := serv.HostType() + switch what { + case dns.TypeANY: return nil, nil, debug, fmt.Errorf("NS record must be an IP address: %s", serv.Host) - case ip.To4() != nil: + + case dns.TypeA, dns.TypeAAAA: serv.Host = msg.Domain(serv.Key) records = append(records, serv.NewNS(state.QName())) - extra = append(extra, serv.NewA(serv.Host, ip.To4())) - case ip.To4() == nil: - serv.Host = msg.Domain(serv.Key) - records = append(records, serv.NewNS(state.QName())) - extra = append(extra, serv.NewAAAA(serv.Host, ip.To16())) + extra = append(extra, newAddress(serv, serv.Host, ip, what)) } } return records, extra, debug, nil @@ -445,6 +444,17 @@ func ErrorToTxt(err error) dns.RR { return t } +func newAddress(s msg.Service, name string, ip net.IP, what uint16) dns.RR { + + hdr := dns.RR_Header{Name: name, Rrtype: what, Class: dns.ClassINET, Ttl: s.TTL} + + if what == dns.TypeA { + return &dns.A{Hdr: hdr, A: ip} + } + // Should always be dns.TypeAAAA + return &dns.AAAA{Hdr: hdr, AAAA: ip} +} + const ( minTTL = 60 hostmaster = "hostmaster" diff --git a/middleware/etcd/msg/type.go b/middleware/etcd/msg/type.go new file mode 100644 index 000000000..807e7b471 --- /dev/null +++ b/middleware/etcd/msg/type.go @@ -0,0 +1,33 @@ +package msg + +import ( + "net" + + "github.com/miekg/dns" +) + +// HostType returns the DNS type of what is encoded in the Service Host field. We're reusing +// dns.TypeXXX to not reinvent a new set of identifiers. +// +// dns.TypeA: the service's Host field contains an A record. +// dns.TypeAAAA: the service's Host field contains an AAAA record. +// dns.TypeANY: the service's Host field contains a name. +// +// Note that a service can double/triple as a TXT record or MX record. +func (s *Service) HostType() (what uint16, normalized net.IP) { + + ip := net.ParseIP(s.Host) + + switch { + case ip == nil: + return dns.TypeANY, nil + + case ip.To4() != nil: + return dns.TypeA, ip.To4() + + case ip.To4() == nil: + return dns.TypeAAAA, ip.To16() + } + // This should never be reached. + return dns.TypeNone, nil +} diff --git a/middleware/etcd/msg/type_test.go b/middleware/etcd/msg/type_test.go new file mode 100644 index 000000000..5fef74b23 --- /dev/null +++ b/middleware/etcd/msg/type_test.go @@ -0,0 +1,31 @@ +package msg + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestType(t *testing.T) { + tests := []struct { + serv Service + expectedType uint16 + }{ + {Service{Host: "example.org"}, dns.TypeANY}, + {Service{Host: "127.0.0.1"}, dns.TypeA}, + {Service{Host: "2000::3"}, dns.TypeAAAA}, + {Service{Host: "2000..3"}, dns.TypeANY}, + {Service{Host: "127.0.0.257"}, dns.TypeANY}, + {Service{Host: "127.0.0.252", Mail: true}, dns.TypeA}, + {Service{Host: "127.0.0.252", Mail: true, Text: "a"}, dns.TypeA}, + {Service{Host: "127.0.0.254", Mail: false, Text: "a"}, dns.TypeA}, + } + + for i, tc := range tests { + what, _ := tc.serv.HostType() + if what != tc.expectedType { + t.Errorf("Test %d: Expected what %v, but got %v", i, tc.expectedType, what) + } + } + +}