dont require/allow "_" prefix for srv wildcard fields (#472)

* dont require/allow "_" prefix for srv wildcard fields

* streamline parse/validation of req name

* removing nametemplate

* error when zone not found, loopify unit tests
This commit is contained in:
Chris O'Haver 2017-01-15 03:12:28 -05:00 committed by Miek Gieben
parent b6a2a5aeaa
commit a6d232a622
9 changed files with 245 additions and 600 deletions

View file

@ -10,7 +10,6 @@ import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/kubernetes/nametemplate"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
dnsstrings "github.com/miekg/coredns/middleware/pkg/strings"
"github.com/miekg/coredns/middleware/proxy"
@ -38,7 +37,6 @@ type Kubernetes struct {
APIClientKey string
APIConn *dnsController
ResyncPeriod time.Duration
NameTemplate *nametemplate.Template
Namespaces []string
LabelSelector *unversionedapi.LabelSelector
Selector *labels.Selector
@ -69,16 +67,22 @@ type pod struct {
addr string
}
type recordRequest struct {
port, protocol, endpoint, service, namespace, typeName, zone string
}
var errNoItems = errors.New("no items found")
var errNsNotExposed = errors.New("namespace is not exposed")
var errInvalidRequest = errors.New("invalid query name")
// Services implements the ServiceBackend interface.
func (k *Kubernetes) Services(state request.Request, exact bool, opt middleware.Options) ([]msg.Service, []msg.Service, error) {
if state.Type() == "SRV" && !ValidSRV(state.Name()) {
return nil, nil, errInvalidRequest
r, e := k.parseRequest(state.Name(), state.Type())
if e != nil {
return nil, nil, e
}
s, e := k.Records(state.Name(), exact)
s, e := k.Records(r)
return s, nil, e // Haven't implemented debug queries yet.
}
@ -177,85 +181,94 @@ func (k *Kubernetes) InitKubeCache() error {
return err
}
// getZoneForName returns the zone string that matches the name and a
// list of the DNS labels from name that are within the zone.
// For example, if "coredns.local" is a zone configured for the
// Kubernetes middleware, then getZoneForName("a.b.coredns.local")
// will return ("coredns.local", ["a", "b"]).
func (k *Kubernetes) getZoneForName(name string) (string, []string) {
var zone string
var serviceSegments []string
func (k *Kubernetes) parseRequest(lowerCasedName, qtype string) (r recordRequest, err error) {
// 3 Possible cases
// SRV Request: _port._protocol.service.namespace.type.zone
// A Request (endpoint): endpoint.service.namespace.type.zone
// A Request (service): service.namespace.type.zone
// separate zone from rest of lowerCasedName
var segs []string
for _, z := range k.Zones {
if dns.IsSubDomain(z, name) {
zone = z
if dns.IsSubDomain(z, lowerCasedName) {
r.zone = z
serviceSegments = dns.SplitDomainName(name)
serviceSegments = serviceSegments[:len(serviceSegments)-dns.CountLabel(zone)]
segs = dns.SplitDomainName(lowerCasedName)
segs = segs[:len(segs)-dns.CountLabel(r.zone)]
break
}
}
return zone, serviceSegments
}
// stripSRVPrefix separates out the port and protocol segments, if present
// If not present, assume all ports/protocols (e.g. wildcard)
func stripSRVPrefix(name []string) (string, string, []string) {
if name[0][0] == '_' && name[1][0] == '_' {
return name[0][1:], name[1][1:], name[2:]
if r.zone == "" {
return r, errors.New("zone not found")
}
// no srv prefix present
return "*", "*", name
}
func stripEndpointName(name []string) (endpoint string, nameOut []string) {
if len(name) == 4 {
return strings.ToLower(name[0]), name[1:]
offset := 0
if len(segs) == 5 {
// This is a SRV style request, get first two elements as port and
// protocol, stripping leading underscores if present.
if segs[0][0] == '_' {
r.port = segs[0][1:]
} else {
r.port = segs[0]
if !symbolContainsWildcard(r.port) {
return r, errors.New("srv port must start with an underscore or be a wildcard")
}
}
if segs[1][0] == '_' {
r.protocol = segs[1][1:]
if r.protocol != "tcp" && r.protocol != "udp" {
return r, errors.New("invalid srv protocol: " + r.protocol)
}
} else {
r.protocol = segs[1]
if !symbolContainsWildcard(r.protocol) {
return r, errors.New("srv protocol must start with an underscore or be a wildcard")
}
}
offset = 2
} else if len(segs) == 4 {
// This is an endpoint A style request. Get first element as endpoint.
r.endpoint = segs[0]
offset = 1
}
return "", name
// SRV requests require a port and protocol
if qtype == "SRV" {
if r.port == "" || r.protocol == "" {
return r, errors.New("invalid srv request")
}
}
// A requests cannot have port/protocol
if qtype == "A" {
if r.port != "" && r.protocol != "" {
return r, errors.New("invalid a request")
}
}
if len(segs) == (offset + 3) {
r.service = segs[offset]
r.namespace = segs[offset+1]
r.typeName = segs[offset+2]
return r, nil
}
return r, errors.New("invalid request")
}
// Records looks up services in kubernetes. If exact is true, it will lookup
// just this name. This is used when find matches when completing SRV lookups
// for instance.
func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
var (
serviceName string
namespace string
typeName string
)
zone, serviceSegments := k.getZoneForName(name)
port, protocol, serviceSegments := stripSRVPrefix(serviceSegments)
endpointname, serviceSegments := stripEndpointName(serviceSegments)
if len(serviceSegments) < 3 {
return nil, errNoItems
}
serviceName = serviceSegments[0]
namespace = serviceSegments[1]
typeName = serviceSegments[2]
if namespace == "" {
err := errors.New("Parsing query string did not produce a namespace value. Assuming wildcard namespace.")
log.Printf("[WARN] %v\n", err)
namespace = "*"
}
if serviceName == "" {
err := errors.New("Parsing query string did not produce a serviceName value. Assuming wildcard serviceName.")
log.Printf("[WARN] %v\n", err)
serviceName = "*"
}
func (k *Kubernetes) Records(r recordRequest) ([]msg.Service, error) {
// Abort if the namespace does not contain a wildcard, and namespace is not published per CoreFile
// Case where namespace contains a wildcard is handled in Get(...) method.
if (!symbolContainsWildcard(namespace)) && (len(k.Namespaces) > 0) && (!dnsstrings.StringInSlice(namespace, k.Namespaces)) {
if (!symbolContainsWildcard(r.namespace)) && (len(k.Namespaces) > 0) && (!dnsstrings.StringInSlice(r.namespace, k.Namespaces)) {
return nil, errNsNotExposed
}
services, pods, err := k.Get(namespace, serviceName, endpointname, port, protocol, typeName)
services, pods, err := k.Get(r)
if err != nil {
return nil, err
}
@ -264,7 +277,7 @@ func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
return nil, errNoItems
}
records := k.getRecordsForK8sItems(services, pods, zone)
records := k.getRecordsForK8sItems(services, pods, r.zone)
return records, nil
}
@ -320,18 +333,6 @@ func (k *Kubernetes) getRecordsForK8sItems(services []service, pods []pod, zone
return records
}
// Get retrieves matching data from the cache.
func (k *Kubernetes) Get(namespace, servicename, endpointname, port, protocol, typeName string) (services []service, pods []pod, err error) {
switch {
case typeName == "pod":
pods, err = k.findPods(namespace, servicename)
return nil, pods, err
default:
services, err = k.findServices(namespace, servicename, endpointname, port, protocol)
return services, nil, err
}
}
func ipFromPodName(podname string) string {
if strings.Count(podname, "-") == 3 && !strings.Contains(podname, "--") {
return strings.Replace(podname, "-", ".", -1)
@ -362,18 +363,30 @@ func (k *Kubernetes) findPods(namespace, podname string) (pods []pod, err error)
}
func (k *Kubernetes) findServices(namespace, servicename, endpointname, port, protocol string) ([]service, error) {
// Get retrieves matching data from the cache.
func (k *Kubernetes) Get(r recordRequest) (services []service, pods []pod, err error) {
switch {
case r.typeName == "pod":
pods, err = k.findPods(r.namespace, r.service)
return nil, pods, err
default:
services, err = k.findServices(r)
return services, nil, err
}
}
func (k *Kubernetes) findServices(r recordRequest) ([]service, error) {
serviceList := k.APIConn.ServiceList()
var resultItems []service
nsWildcard := symbolContainsWildcard(namespace)
serviceWildcard := symbolContainsWildcard(servicename)
portWildcard := symbolContainsWildcard(port)
protocolWildcard := symbolContainsWildcard(protocol)
nsWildcard := symbolContainsWildcard(r.namespace)
serviceWildcard := symbolContainsWildcard(r.service)
portWildcard := symbolContainsWildcard(r.port) || r.port == ""
protocolWildcard := symbolContainsWildcard(r.protocol) || r.protocol == ""
for _, svc := range serviceList {
if !(symbolMatches(namespace, svc.Namespace, nsWildcard) && symbolMatches(servicename, svc.Name, serviceWildcard)) {
if !(symbolMatches(r.namespace, svc.Namespace, nsWildcard) && symbolMatches(r.service, svc.Name, serviceWildcard)) {
continue
}
// If namespace has a wildcard, filter results against Corefile namespace list.
@ -384,7 +397,7 @@ func (k *Kubernetes) findServices(namespace, servicename, endpointname, port, pr
s := service{name: svc.Name, namespace: svc.Namespace, addr: svc.Spec.ClusterIP}
if s.addr != api.ClusterIPNone {
for _, p := range svc.Spec.Ports {
if !(symbolMatches(port, strings.ToLower(p.Name), portWildcard) && symbolMatches(protocol, strings.ToLower(string(p.Protocol)), protocolWildcard)) {
if !(symbolMatches(r.port, strings.ToLower(p.Name), portWildcard) && symbolMatches(r.protocol, strings.ToLower(string(p.Protocol)), protocolWildcard)) {
continue
}
s.ports = append(s.ports, p)
@ -405,10 +418,10 @@ func (k *Kubernetes) findServices(namespace, servicename, endpointname, port, pr
for _, addr := range eps.Addresses {
for _, p := range eps.Ports {
ephostname := endpointHostname(addr)
if endpointname != "" && endpointname != ephostname {
if r.endpoint != "" && r.endpoint != ephostname {
continue
}
if !(symbolMatches(port, strings.ToLower(p.Name), portWildcard) && symbolMatches(protocol, strings.ToLower(string(p.Protocol)), protocolWildcard)) {
if !(symbolMatches(r.port, strings.ToLower(p.Name), portWildcard) && symbolMatches(r.protocol, strings.ToLower(string(p.Protocol)), protocolWildcard)) {
continue
}
s.endpoints = append(s.endpoints, endpoint{addr: addr, port: p})
@ -422,16 +435,10 @@ func (k *Kubernetes) findServices(namespace, servicename, endpointname, port, pr
}
func symbolMatches(queryString, candidateString string, wildcard bool) bool {
result := false
switch {
case !wildcard:
result = (queryString == candidateString)
case queryString == "*":
result = true
case queryString == "any":
result = true
if wildcard {
return true
}
return result
return queryString == candidateString
}
// getServiceRecordForIP: Gets a service record with a cluster ip matching the ip argument
@ -476,57 +483,3 @@ func (k *Kubernetes) getServiceRecordForIP(ip, name string) []msg.Service {
func symbolContainsWildcard(symbol string) bool {
return (strings.Contains(symbol, "*") || (symbol == "any"))
}
// ValidSRV parses a server record validating _port._proto. prefix labels.
// The valid schema is:
// * Fist two segments must start with an "_",
// * Second segment must be one of _tcp|_udp|_*|_any
func ValidSRV(name string) bool {
// Does it start with a "_" ?
if len(name) > 0 && name[0] != '_' {
return false
}
// First label
first, end := dns.NextLabel(name, 0)
if end {
return false
}
// Second label
off, end := dns.NextLabel(name, first)
if end {
return false
}
// first:off has captured _tcp. or _udp. (if present)
second := name[first:off]
if len(second) > 0 && second[0] != '_' {
return false
}
// A bit convoluted to avoid strings.ToLower
if len(second) == 5 {
// matches _tcp
if (second[1] == 't' || second[1] == 'T') && (second[2] == 'c' || second[2] == 'C') &&
(second[3] == 'p' || second[3] == 'P') {
return true
}
// matches _udp
if (second[1] == 'u' || second[1] == 'U') && (second[2] == 'd' || second[2] == 'D') &&
(second[3] == 'p' || second[3] == 'P') {
return true
}
// matches _any
if (second[1] == 'a' || second[1] == 'A') && (second[2] == 'n' || second[2] == 'N') &&
(second[3] == 'y' || second[3] == 'Y') {
return true
}
}
// matches _*
if len(second) == 3 && second[1] == '*' {
return true
}
return false
}