plugin/forward: gracefull stop (#1701)
* plugin/forward: gracefull stop - stop connection manager only when no queries in progress * minor improvement * prevent healthcheck on stopped proxy * revert closing channels * use standard context
This commit is contained in:
parent
ad13d88346
commit
135377bf77
4 changed files with 100 additions and 5 deletions
|
@ -35,6 +35,16 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
||||||
|
atomic.AddInt32(&p.inProgress, 1)
|
||||||
|
defer func() {
|
||||||
|
if atomic.AddInt32(&p.inProgress, -1) == 0 {
|
||||||
|
p.checkStopTransport()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if atomic.LoadUint32(&p.state) != running {
|
||||||
|
return nil, errStopped
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
proto := state.Proto()
|
proto := state.Proto()
|
||||||
|
@ -46,7 +56,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set buffer size correctly for this client.
|
// Set buffer size correctly for this client.
|
||||||
conn.UDPSize = uint16(state.Size())
|
conn.UDPSize = uint16(state.Size())
|
||||||
if conn.UDPSize < 512 {
|
if conn.UDPSize < 512 {
|
||||||
|
|
|
@ -120,7 +120,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Kick off health check to see if *our* upstream is broken.
|
// Kick off health check to see if *our* upstream is broken.
|
||||||
if f.maxfails != 0 {
|
if f.maxfails != 0 && err != errStopped {
|
||||||
proxy.Healthcheck()
|
proxy.Healthcheck()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,6 +186,7 @@ var (
|
||||||
errNoHealthy = errors.New("no healthy proxies")
|
errNoHealthy = errors.New("no healthy proxies")
|
||||||
errNoForward = errors.New("no forwarder defined")
|
errNoForward = errors.New("no forwarder defined")
|
||||||
errCachedClosed = errors.New("cached connection was closed by peer")
|
errCachedClosed = errors.New("cached connection was closed by peer")
|
||||||
|
errStopped = errors.New("proxy has been stopped")
|
||||||
)
|
)
|
||||||
|
|
||||||
// policy tells forward what policy for selecting upstream it uses.
|
// policy tells forward what policy for selecting upstream it uses.
|
||||||
|
|
|
@ -24,6 +24,9 @@ type Proxy struct {
|
||||||
fails uint32
|
fails uint32
|
||||||
|
|
||||||
avgRtt int64
|
avgRtt int64
|
||||||
|
|
||||||
|
state uint32
|
||||||
|
inProgress int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProxy returns a new proxy.
|
// NewProxy returns a new proxy.
|
||||||
|
@ -79,15 +82,26 @@ func (p *Proxy) Down(maxfails uint32) bool {
|
||||||
return fails > maxfails
|
return fails > maxfails
|
||||||
}
|
}
|
||||||
|
|
||||||
// close stops the health checking goroutine.
|
// close stops the health checking goroutine and connection manager.
|
||||||
func (p *Proxy) close() {
|
func (p *Proxy) close() {
|
||||||
|
if atomic.CompareAndSwapUint32(&p.state, running, stopping) {
|
||||||
p.probe.Stop()
|
p.probe.Stop()
|
||||||
p.transport.Stop()
|
}
|
||||||
|
if atomic.LoadInt32(&p.inProgress) == 0 {
|
||||||
|
p.checkStopTransport()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// start starts the proxy's healthchecking.
|
// start starts the proxy's healthchecking.
|
||||||
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
|
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
|
||||||
|
|
||||||
|
// checkStopTransport checks if stop was requested and stops connection manager
|
||||||
|
func (p *Proxy) checkStopTransport() {
|
||||||
|
if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) {
|
||||||
|
p.transport.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 4 * time.Second
|
dialTimeout = 4 * time.Second
|
||||||
timeout = 2 * time.Second
|
timeout = 2 * time.Second
|
||||||
|
@ -95,3 +109,9 @@ const (
|
||||||
minTimeout = 10 * time.Millisecond
|
minTimeout = 10 * time.Millisecond
|
||||||
hcDuration = 500 * time.Millisecond
|
hcDuration = 500 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
running = iota
|
||||||
|
stopping
|
||||||
|
stopped
|
||||||
|
)
|
||||||
|
|
65
plugin/forward/proxy_test.go
Normal file
65
plugin/forward/proxy_test.go
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
package forward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||||
|
"github.com/coredns/coredns/plugin/test"
|
||||||
|
"github.com/coredns/coredns/request"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxyClose(t *testing.T) {
|
||||||
|
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
ret := new(dns.Msg)
|
||||||
|
ret.SetReply(r)
|
||||||
|
w.WriteMsg(ret)
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := new(dns.Msg)
|
||||||
|
req.SetQuestion("example.org.", dns.TypeA)
|
||||||
|
state := request.Request{W: &test.ResponseWriter{}, Req: req}
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
|
repeatCnt := 1000
|
||||||
|
for repeatCnt > 0 {
|
||||||
|
repeatCnt--
|
||||||
|
p := NewProxy(s.Addr, nil /* no TLS */)
|
||||||
|
p.start(hcDuration)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(5)
|
||||||
|
go func() {
|
||||||
|
p.connect(ctx, state, false, false)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
p.connect(ctx, state, true, false)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
p.close()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
p.connect(ctx, state, false, false)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
p.connect(ctx, state, true, false)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if p.inProgress != 0 {
|
||||||
|
t.Errorf("unexpected query in progress")
|
||||||
|
}
|
||||||
|
if p.state != stopped {
|
||||||
|
t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue