diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 0a66f2752..6ea7913e5 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -35,6 +35,16 @@ func (p *Proxy) updateRtt(newRtt time.Duration) { } func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { + atomic.AddInt32(&p.inProgress, 1) + defer func() { + if atomic.AddInt32(&p.inProgress, -1) == 0 { + p.checkStopTransport() + } + }() + if atomic.LoadUint32(&p.state) != running { + return nil, errStopped + } + start := time.Now() proto := state.Proto() @@ -46,7 +56,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me if err != nil { return nil, err } - // Set buffer size correctly for this client. conn.UDPSize = uint16(state.Size()) if conn.UDPSize < 512 { diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 153c5ab38..213b30f8b 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -120,7 +120,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg if err != nil { // Kick off health check to see if *our* upstream is broken. - if f.maxfails != 0 { + if f.maxfails != 0 && err != errStopped { proxy.Healthcheck() } @@ -186,6 +186,7 @@ var ( errNoHealthy = errors.New("no healthy proxies") errNoForward = errors.New("no forwarder defined") errCachedClosed = errors.New("cached connection was closed by peer") + errStopped = errors.New("proxy has been stopped") ) // policy tells forward what policy for selecting upstream it uses. diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index 3271e7dd9..8454b296d 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -24,6 +24,9 @@ type Proxy struct { fails uint32 avgRtt int64 + + state uint32 + inProgress int32 } // NewProxy returns a new proxy. @@ -79,15 +82,26 @@ func (p *Proxy) Down(maxfails uint32) bool { return fails > maxfails } -// close stops the health checking goroutine. +// close stops the health checking goroutine and connection manager. func (p *Proxy) close() { - p.probe.Stop() - p.transport.Stop() + if atomic.CompareAndSwapUint32(&p.state, running, stopping) { + p.probe.Stop() + } + if atomic.LoadInt32(&p.inProgress) == 0 { + p.checkStopTransport() + } } // start starts the proxy's healthchecking. func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) } +// checkStopTransport checks if stop was requested and stops connection manager +func (p *Proxy) checkStopTransport() { + if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) { + p.transport.Stop() + } +} + const ( dialTimeout = 4 * time.Second timeout = 2 * time.Second @@ -95,3 +109,9 @@ const ( minTimeout = 10 * time.Millisecond hcDuration = 500 * time.Millisecond ) + +const ( + running = iota + stopping + stopped +) diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go new file mode 100644 index 000000000..8c53f3150 --- /dev/null +++ b/plugin/forward/proxy_test.go @@ -0,0 +1,65 @@ +package forward + +import ( + "context" + "sync" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestProxyClose(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + state := request.Request{W: &test.ResponseWriter{}, Req: req} + ctx := context.TODO() + + repeatCnt := 1000 + for repeatCnt > 0 { + repeatCnt-- + p := NewProxy(s.Addr, nil /* no TLS */) + p.start(hcDuration) + + var wg sync.WaitGroup + wg.Add(5) + go func() { + p.connect(ctx, state, false, false) + wg.Done() + }() + go func() { + p.connect(ctx, state, true, false) + wg.Done() + }() + go func() { + p.close() + wg.Done() + }() + go func() { + p.connect(ctx, state, false, false) + wg.Done() + }() + go func() { + p.connect(ctx, state, true, false) + wg.Done() + }() + wg.Wait() + + if p.inProgress != 0 { + t.Errorf("unexpected query in progress") + } + if p.state != stopped { + t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state) + } + } +}