From d902e859199e4085cd27453f30367fd1b0799bc5 Mon Sep 17 00:00:00 2001
From: Chris O'Haver <cohaver@infoblox.com>
Date: Mon, 15 Jun 2020 10:15:41 -0400
Subject: [PATCH] plugin/kubernetes: fix tombstone unwrapping (#3924)

* fix tombstone unwrapping

Signed-off-by: Chris O'Haver <cohaver@infoblox.com>
---
 plugin/kubernetes/controller.go      |  87 ++++-----------------
 plugin/kubernetes/informer_test.go   | 111 +++++++++++++++++++++++++++
 plugin/kubernetes/object/endpoint.go |  21 ++++-
 plugin/kubernetes/object/informer.go |  27 ++++---
 4 files changed, 163 insertions(+), 83 deletions(-)
 create mode 100644 plugin/kubernetes/informer_test.go

diff --git a/plugin/kubernetes/controller.go b/plugin/kubernetes/controller.go
index 01cce28f2..90a005177 100644
--- a/plugin/kubernetes/controller.go
+++ b/plugin/kubernetes/controller.go
@@ -113,7 +113,7 @@ func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, opts
 		&api.Service{},
 		cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
 		cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc, svcIPIndex: svcIPIndexFunc},
-		object.DefaultProcessor(object.ToService(opts.skipAPIObjectsCleanup)),
+		object.DefaultProcessor(object.ToService(opts.skipAPIObjectsCleanup), nil),
 	)
 
 	if opts.initPodCache {
@@ -125,7 +125,7 @@ func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, opts
 			&api.Pod{},
 			cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
 			cache.Indexers{podIPIndex: podIPIndexFunc},
-			object.DefaultProcessor(object.ToPod(opts.skipAPIObjectsCleanup)),
+			object.DefaultProcessor(object.ToPod(opts.skipAPIObjectsCleanup), nil),
 		)
 	}
 
@@ -136,73 +136,10 @@ func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, opts
 				WatchFunc: endpointsWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
 			},
 			&api.Endpoints{},
-			cache.ResourceEventHandlerFuncs{},
+			cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
 			cache.Indexers{epNameNamespaceIndex: epNameNamespaceIndexFunc, epIPIndex: epIPIndexFunc},
-			func(clientState cache.Indexer, h cache.ResourceEventHandler) cache.ProcessFunc {
-				return func(obj interface{}) error {
-					for _, d := range obj.(cache.Deltas) {
-						switch d.Type {
-						case cache.Sync, cache.Added, cache.Updated:
-							apiEndpoints, ok := d.Object.(*api.Endpoints)
-							if !ok {
-								return errors.New("got non-endpoint add/update")
-							}
-							obj := object.ToEndpoints(apiEndpoints)
-
-							if old, exists, err := clientState.Get(obj); err == nil && exists {
-								if err := clientState.Update(obj); err != nil {
-									return err
-								}
-								h.OnUpdate(old, obj)
-								// endpoint updates can come frequently, make sure it's a change we care about
-								if !endpointsEquivalent(old.(*object.Endpoints), obj) {
-									dns.updateModifed()
-									recordDNSProgrammingLatency(dns.getServices(obj), apiEndpoints)
-								}
-							} else {
-								if err := clientState.Add(obj); err != nil {
-									return err
-								}
-								h.OnAdd(d.Object)
-								dns.updateModifed()
-								recordDNSProgrammingLatency(dns.getServices(obj), apiEndpoints)
-								if !opts.skipAPIObjectsCleanup {
-									*apiEndpoints = api.Endpoints{}
-								}
-							}
-						case cache.Deleted:
-							apiEndpoints, ok := d.Object.(*api.Endpoints)
-							if !ok {
-								// Assume that the object must be a cache.DeletedFinalStateUnknown.
-								// This is essentially an indicator that the Endpoint was deleted, without a containing a
-								// up-to date copy of the Endpoints object. We need to use cache.DeletedFinalStateUnknown
-								// object so it can be properly deleted by store.Delete() below, which knows how to handle it.
-								tombstone, ok := d.Object.(cache.DeletedFinalStateUnknown)
-								if !ok {
-									return errors.New("expected tombstone")
-								}
-								apiEndpoints, ok = tombstone.Obj.(*api.Endpoints)
-								if !ok {
-									return errors.New("got non-endpoint tombstone")
-								}
-							}
-							obj := object.ToEndpoints(apiEndpoints)
-
-							if err := clientState.Delete(obj); err != nil {
-								return err
-							}
-							h.OnDelete(d.Object)
-							dns.updateModifed()
-							recordDNSProgrammingLatency(dns.getServices(obj), apiEndpoints)
-							if !opts.skipAPIObjectsCleanup {
-								*apiEndpoints = api.Endpoints{}
-							}
-						}
-					}
-					return nil
-				}
-			})
-
+			object.DefaultProcessor(object.ToEndpoints(opts.skipAPIObjectsCleanup), dns.recordDNSProgrammingLatency),
+		)
 	}
 
 	dns.nsLister, dns.nsController = cache.NewInformer(
@@ -217,6 +154,10 @@ func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, opts
 	return &dns
 }
 
+func (dns *dnsControl) recordDNSProgrammingLatency(obj meta.Object) {
+	recordDNSProgrammingLatency(dns.getServices(obj.(*api.Endpoints)), obj.(*api.Endpoints))
+}
+
 func podIPIndexFunc(obj interface{}) ([]string, error) {
 	p, ok := obj.(*object.Pod)
 	if !ok {
@@ -472,8 +413,8 @@ func (dns *dnsControl) GetNamespaceByName(name string) (*api.Namespace, error) {
 	return nil, fmt.Errorf("namespace not found")
 }
 
-func (dns *dnsControl) Add(obj interface{})               { dns.detectChanges(nil, obj) }
-func (dns *dnsControl) Delete(obj interface{})            { dns.detectChanges(obj, nil) }
+func (dns *dnsControl) Add(obj interface{})               { dns.updateModifed() }
+func (dns *dnsControl) Delete(obj interface{})            { dns.updateModifed() }
 func (dns *dnsControl) Update(oldObj, newObj interface{}) { dns.detectChanges(oldObj, newObj) }
 
 // detectChanges detects changes in objects, and updates the modified timestamp
@@ -491,12 +432,16 @@ func (dns *dnsControl) detectChanges(oldObj, newObj interface{}) {
 		dns.updateModifed()
 	case *object.Pod:
 		dns.updateModifed()
+	case *object.Endpoints:
+		if !endpointsEquivalent(oldObj.(*object.Endpoints), newObj.(*object.Endpoints)) {
+			dns.updateModifed()
+		}
 	default:
 		log.Warningf("Updates for %T not supported.", ob)
 	}
 }
 
-func (dns *dnsControl) getServices(endpoints *object.Endpoints) []*object.Service {
+func (dns *dnsControl) getServices(endpoints *api.Endpoints) []*object.Service {
 	return dns.SvcIndex(object.EndpointsKey(endpoints.GetName(), endpoints.GetNamespace()))
 }
 
diff --git a/plugin/kubernetes/informer_test.go b/plugin/kubernetes/informer_test.go
new file mode 100644
index 000000000..d47d6feee
--- /dev/null
+++ b/plugin/kubernetes/informer_test.go
@@ -0,0 +1,111 @@
+package kubernetes
+
+import (
+	"testing"
+
+	"github.com/coredns/coredns/plugin/kubernetes/object"
+
+	api "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/client-go/tools/cache"
+)
+
+func TestDefaultProcessor(t *testing.T) {
+	pbuild := object.DefaultProcessor(object.ToService(true), nil)
+	reh := cache.ResourceEventHandlerFuncs{}
+	idx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{})
+	processor := pbuild(idx, reh)
+	testProcessor(t, processor, idx)
+}
+
+func testProcessor(t *testing.T, processor cache.ProcessFunc, idx cache.Indexer) {
+	obj := &api.Service{
+		ObjectMeta: metav1.ObjectMeta{Name: "service1", Namespace: "test1"},
+		Spec:       api.ServiceSpec{ClusterIP: "1.2.3.4", Ports: []api.ServicePort{{Port: 80}}},
+	}
+	obj2 := &api.Service{
+		ObjectMeta: metav1.ObjectMeta{Name: "service2", Namespace: "test1"},
+		Spec:       api.ServiceSpec{ClusterIP: "5.6.7.8", Ports: []api.ServicePort{{Port: 80}}},
+	}
+
+	// Add the objects
+	err := processor(cache.Deltas{
+		{Type: cache.Added, Object: obj},
+		{Type: cache.Added, Object: obj2},
+	})
+	if err != nil {
+		t.Fatalf("add failed: %v", err)
+	}
+	got, exists, err := idx.Get(obj)
+	if err != nil {
+		t.Fatalf("get added object failed: %v", err)
+	}
+	if !exists {
+		t.Fatal("added object not found in index")
+	}
+	svc, ok := got.(*object.Service)
+	if !ok {
+		t.Fatal("object in index was incorrect type")
+	}
+	if svc.ClusterIP != obj.Spec.ClusterIP {
+		t.Fatalf("expected %v, got %v", obj.Spec.ClusterIP, svc.ClusterIP)
+	}
+
+	// Update an object
+	obj.Spec.ClusterIP = "1.2.3.5"
+	err = processor(cache.Deltas{{
+		Type:   cache.Updated,
+		Object: obj,
+	}})
+	if err != nil {
+		t.Fatalf("update failed: %v", err)
+	}
+	got, exists, err = idx.Get(obj)
+	if err != nil {
+		t.Fatalf("get updated object failed: %v", err)
+	}
+	if !exists {
+		t.Fatal("updated object not found in index")
+	}
+	svc, ok = got.(*object.Service)
+	if !ok {
+		t.Fatal("object in index was incorrect type")
+	}
+	if svc.ClusterIP != obj.Spec.ClusterIP {
+		t.Fatalf("expected %v, got %v", obj.Spec.ClusterIP, svc.ClusterIP)
+	}
+
+	// Delete an object
+	err = processor(cache.Deltas{{
+		Type:   cache.Deleted,
+		Object: obj2,
+	}})
+	if err != nil {
+		t.Fatalf("delete test failed: %v", err)
+	}
+	got, exists, err = idx.Get(obj2)
+	if err != nil {
+		t.Fatalf("get deleted object failed: %v", err)
+	}
+	if exists {
+		t.Fatal("deleted object found in index")
+	}
+
+	// Delete an object via tombstone
+	key, _ := cache.MetaNamespaceKeyFunc(obj)
+	tombstone := cache.DeletedFinalStateUnknown{Key: key, Obj: svc}
+	err = processor(cache.Deltas{{
+		Type:   cache.Deleted,
+		Object: tombstone,
+	}})
+	if err != nil {
+		t.Fatalf("tombstone delete test failed: %v", err)
+	}
+	got, exists, err = idx.Get(svc)
+	if err != nil {
+		t.Fatalf("get tombstone deleted object failed: %v", err)
+	}
+	if exists {
+		t.Fatal("tombstone deleted object found in index")
+	}
+}
diff --git a/plugin/kubernetes/object/endpoint.go b/plugin/kubernetes/object/endpoint.go
index 2a7d69acf..f3ce9c2d6 100644
--- a/plugin/kubernetes/object/endpoint.go
+++ b/plugin/kubernetes/object/endpoint.go
@@ -1,6 +1,8 @@
 package object
 
 import (
+	"fmt"
+
 	api "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 )
@@ -43,8 +45,19 @@ type EndpointPort struct {
 // EndpointsKey return a string using for the index.
 func EndpointsKey(name, namespace string) string { return name + "." + namespace }
 
-// ToEndpoints converts an api.Endpoints to a *Endpoints.
-func ToEndpoints(end *api.Endpoints) *Endpoints {
+// ToEndpoints returns a function that converts an *api.Endpoints to a *Endpoints.
+func ToEndpoints(skipCleanup bool) ToFunc {
+	return func(obj interface{}) (interface{}, error) {
+		eps, ok := obj.(*api.Endpoints)
+		if !ok {
+			return nil, fmt.Errorf("unexpected object %v", obj)
+		}
+		return toEndpoints(skipCleanup, eps), nil
+	}
+}
+
+// toEndpoints converts an *api.Endpoints to a *Endpoints.
+func toEndpoints(skipCleanup bool, end *api.Endpoints) *Endpoints {
 	e := &Endpoints{
 		Version:   end.GetResourceVersion(),
 		Name:      end.GetName(),
@@ -88,6 +101,10 @@ func ToEndpoints(end *api.Endpoints) *Endpoints {
 		}
 	}
 
+	if !skipCleanup {
+		*end = api.Endpoints{}
+	}
+
 	return e
 }
 
diff --git a/plugin/kubernetes/object/informer.go b/plugin/kubernetes/object/informer.go
index e0d7f180c..f37af4796 100644
--- a/plugin/kubernetes/object/informer.go
+++ b/plugin/kubernetes/object/informer.go
@@ -1,6 +1,7 @@
 package object
 
 import (
+	meta "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/client-go/tools/cache"
 )
@@ -20,8 +21,10 @@ func NewIndexerInformer(lw cache.ListerWatcher, objType runtime.Object, h cache.
 	return clientState, cache.New(cfg)
 }
 
-// DefaultProcessor is a copy of Process function from cache.NewIndexerInformer except it does a conversion.
-func DefaultProcessor(convert ToFunc) ProcessorBuilder {
+type recordLatencyFunc func(meta.Object)
+
+// DefaultProcessor is based on the Process function from cache.NewIndexerInformer except it does a conversion.
+func DefaultProcessor(convert ToFunc, recordLatency recordLatencyFunc) ProcessorBuilder {
 	return func(clientState cache.Indexer, h cache.ResourceEventHandler) cache.ProcessFunc {
 		return func(obj interface{}) error {
 			for _, d := range obj.(cache.Deltas) {
@@ -42,23 +45,27 @@ func DefaultProcessor(convert ToFunc) ProcessorBuilder {
 						}
 						h.OnAdd(obj)
 					}
+					if recordLatency != nil {
+						recordLatency(d.Object.(meta.Object))
+					}
 				case cache.Deleted:
 					var obj interface{}
-					var err error
-					tombstone, ok := d.Object.(cache.DeletedFinalStateUnknown)
-					if ok {
-						obj, err = convert(tombstone.Obj)
-					} else {
+					obj, ok := d.Object.(cache.DeletedFinalStateUnknown)
+					if !ok {
+						var err error
 						obj, err = convert(d.Object)
-					}
-					if err != nil && err != errPodTerminating {
-						return err
+						if err != nil && err != errPodTerminating {
+							return err
+						}
 					}
 
 					if err := clientState.Delete(obj); err != nil {
 						return err
 					}
 					h.OnDelete(obj)
+					if !ok && recordLatency != nil {
+						recordLatency(d.Object.(meta.Object))
+					}
 				}
 			}
 			return nil