From 135377bf776295d8ef86081c1ef581e7b41d26f0 Mon Sep 17 00:00:00 2001
From: Ruslan Drozhdzh <30860269+rdrozhdzh@users.noreply.github.com>
Date: Fri, 20 Apr 2018 17:47:46 +0300
Subject: [PATCH] plugin/forward: graceful stop (#1701)

* plugin/forward: graceful stop

 - stop connection manager only when no queries in progress

* minor improvement

* prevent healthcheck on stopped proxy

* revert closing channels

* use standard context
---
 plugin/forward/connect.go    | 11 +++++-
 plugin/forward/forward.go    |  3 +-
 plugin/forward/proxy.go      | 26 +++++++++++++--
 plugin/forward/proxy_test.go | 65 ++++++++++++++++++++++++++++++++++++
 4 files changed, 100 insertions(+), 5 deletions(-)
 create mode 100644 plugin/forward/proxy_test.go

diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go
index 0a66f2752..6ea7913e5 100644
--- a/plugin/forward/connect.go
+++ b/plugin/forward/connect.go
@@ -35,6 +35,16 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
 }
 
 func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
+	atomic.AddInt32(&p.inProgress, 1)
+	defer func() {
+		if atomic.AddInt32(&p.inProgress, -1) == 0 {
+			p.checkStopTransport()
+		}
+	}()
+	if atomic.LoadUint32(&p.state) != running {
+		return nil, errStopped
+	}
+
 	start := time.Now()
 
 	proto := state.Proto()
@@ -46,7 +56,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
 	if err != nil {
 		return nil, err
 	}
-
 	// Set buffer size correctly for this client.
 	conn.UDPSize = uint16(state.Size())
 	if conn.UDPSize < 512 {
diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go
index 153c5ab38..213b30f8b 100644
--- a/plugin/forward/forward.go
+++ b/plugin/forward/forward.go
@@ -120,7 +120,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
 
 		if err != nil {
 			// Kick off health check to see if *our* upstream is broken.
-			if f.maxfails != 0 {
+			if f.maxfails != 0 && err != errStopped {
 				proxy.Healthcheck()
 			}
 
@@ -186,6 +186,7 @@ var (
 	errNoHealthy     = errors.New("no healthy proxies")
 	errNoForward     = errors.New("no forwarder defined")
 	errCachedClosed  = errors.New("cached connection was closed by peer")
+	errStopped       = errors.New("proxy has been stopped")
 )
 
 // policy tells forward what policy for selecting upstream it uses.
diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go
index 3271e7dd9..8454b296d 100644
--- a/plugin/forward/proxy.go
+++ b/plugin/forward/proxy.go
@@ -24,6 +24,9 @@ type Proxy struct {
 	fails uint32
 
 	avgRtt int64
+
+	state      uint32
+	inProgress int32
 }
 
 // NewProxy returns a new proxy.
@@ -79,15 +82,26 @@ func (p *Proxy) Down(maxfails uint32) bool {
 	return fails > maxfails
 }
 
-// close stops the health checking goroutine.
+// close stops the health checking goroutine and, once no queries are in progress, the connection manager.
 func (p *Proxy) close() {
-	p.probe.Stop()
-	p.transport.Stop()
+	if atomic.CompareAndSwapUint32(&p.state, running, stopping) { // only the call that wins running -> stopping stops the probe
+		p.probe.Stop()
+	}
+	if atomic.LoadInt32(&p.inProgress) == 0 { // otherwise connect's deferred decrement triggers the stop when the count reaches zero
+		p.checkStopTransport()
+	}
 }
 
 // start starts the proxy's healthchecking.
 func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
 
+// checkStopTransport stops the connection manager if the proxy is in the stopping state, moving it to stopped.
+func (p *Proxy) checkStopTransport() {
+	if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) { // only the caller that wins the CAS stops the transport
+		p.transport.Stop()
+	}
+}
+
 const (
 	dialTimeout = 4 * time.Second
 	timeout     = 2 * time.Second
@@ -95,3 +109,9 @@ const (
 	minTimeout  = 10 * time.Millisecond
 	hcDuration  = 500 * time.Millisecond
 )
+
+const (
+	running = iota // zero value of Proxy.state; connect proceeds only in this state
+	stopping       // set by close; the transport is stopped once no queries are in progress
+	stopped        // set by checkStopTransport just before it stops the transport
+)
diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go
new file mode 100644
index 000000000..8c53f3150
--- /dev/null
+++ b/plugin/forward/proxy_test.go
@@ -0,0 +1,65 @@
+package forward
+
+import (
+	"context"
+	"sync"
+	"testing"
+
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+	"github.com/coredns/coredns/plugin/test"
+	"github.com/coredns/coredns/request"
+
+	"github.com/miekg/dns"
+)
+
+func TestProxyClose(t *testing.T) {
+	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { // upstream that replies to every query
+		ret := new(dns.Msg)
+		ret.SetReply(r)
+		w.WriteMsg(ret)
+	})
+	defer s.Close()
+
+	req := new(dns.Msg)
+	req.SetQuestion("example.org.", dns.TypeA)
+	state := request.Request{W: &test.ResponseWriter{}, Req: req}
+	ctx := context.TODO()
+
+	repeatCnt := 1000 // NOTE(review): presumably repeated to exercise many close/connect interleavings
+	for repeatCnt > 0 {
+		repeatCnt--
+		p := NewProxy(s.Addr, nil /* no TLS */)
+		p.start(hcDuration)
+
+		var wg sync.WaitGroup
+		wg.Add(5) // four connect calls plus one close, all started concurrently
+		go func() {
+			p.connect(ctx, state, false, false)
+			wg.Done()
+		}()
+		go func() {
+			p.connect(ctx, state, true, false) // forceTCP = true
+			wg.Done()
+		}()
+		go func() {
+			p.close()
+			wg.Done()
+		}()
+		go func() {
+			p.connect(ctx, state, false, false)
+			wg.Done()
+		}()
+		go func() {
+			p.connect(ctx, state, true, false) // forceTCP = true
+			wg.Done()
+		}()
+		wg.Wait() // every connect and the close have returned past this point
+
+		if p.inProgress != 0 { // each connect must have decremented the counter it incremented
+			t.Errorf("unexpected query in progress")
+		}
+		if p.state != stopped { // whichever call finished last must have stopped the transport
+			t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
+		}
+	}
+}