diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go index 0d8c68976..af2de43f0 100644 --- a/middleware/kubernetes/kubernetes.go +++ b/middleware/kubernetes/kubernetes.go @@ -225,7 +225,10 @@ func (k *Kubernetes) parseRequest(lowerCasedName, qtype string) (r recordRequest } offset := 0 - if len(segs) == 5 { + if qtype == "SRV" { + if len(segs) != 5 { + return r, errInvalidRequest + } // This is a SRV style request, get first two elements as port and // protocol, stripping leading underscores if present. if segs[0][0] == '_' { @@ -233,40 +236,31 @@ func (k *Kubernetes) parseRequest(lowerCasedName, qtype string) (r recordRequest } else { r.port = segs[0] if !symbolContainsWildcard(r.port) { - return r, errors.New("srv port must start with an underscore or be a wildcard") + return r, errInvalidRequest } } if segs[1][0] == '_' { r.protocol = segs[1][1:] if r.protocol != "tcp" && r.protocol != "udp" { - return r, errors.New("invalid srv protocol: " + r.protocol) + return r, errInvalidRequest } } else { r.protocol = segs[1] if !symbolContainsWildcard(r.protocol) { - return r, errors.New("srv protocol must start with an underscore or be a wildcard") + return r, errInvalidRequest } } + if r.port == "" || r.protocol == "" { + return r, errInvalidRequest + } offset = 2 - } else if len(segs) == 4 { - // This is an endpoint A style request. Get first element as endpoint. + } + if qtype == "A" && len(segs) == 4 { + // This is an endpoint A record request. Get first element as endpoint. r.endpoint = segs[0] offset = 1 } - // SRV requests require a port and protocol - if qtype == "SRV" { - if r.port == "" || r.protocol == "" { - return r, errors.New("invalid srv request") - } - } - // A requests cannot have port/protocol - if qtype == "A" { - if r.port != "" && r.protocol != "" { - return r, errors.New("invalid a request") - } - } - if len(segs) == (offset + 3) { r.service = segs[offset] r.namespace = segs[offset+1] @@ -280,7 +274,7 @@ func (k *Kubernetes) parseRequest(lowerCasedName, qtype string) (r recordRequest return r, nil } - return r, errors.New("invalid request") + return r, errInvalidRequest } diff --git a/test/kubernetes_test.go b/test/kubernetes_test.go index cdb8add2b..9e6be8d28 100644 --- a/test/kubernetes_test.go +++ b/test/kubernetes_test.go @@ -181,17 +181,17 @@ var dnsTestCases = []test.Case{ }, { Qname: "*.svc-1-a.test-1.svc.cluster.local.", Qtype: dns.TypeSRV, - Rcode: dns.RcodeServerFailure, + Rcode: dns.RcodeNameError, Answer: []dns.RR{}, }, { Qname: "*._not-udp-or-tcp.svc-1-a.test-1.svc.cluster.local.", Qtype: dns.TypeSRV, - Rcode: dns.RcodeServerFailure, + Rcode: dns.RcodeNameError, Answer: []dns.RR{}, }, { Qname: "svc-1-a.test-1.svc.cluster.local.", Qtype: dns.TypeSRV, - Rcode: dns.RcodeServerFailure, + Rcode: dns.RcodeNameError, Answer: []dns.RR{}, }, {