diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 40c9d62ca..5bd55f2ab 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -34,16 +34,6 @@ func (p *Proxy) updateRtt(newRtt time.Duration) { } func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { - atomic.AddInt32(&p.inProgress, 1) - defer func() { - if atomic.AddInt32(&p.inProgress, -1) == 0 { - p.checkStopTransport() - } - }() - if atomic.LoadUint32(&p.state) != running { - return nil, errStopped - } - start := time.Now() proto := state.Proto() @@ -55,6 +45,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me if err != nil { return nil, err } + // Set buffer size correctly for this client. conn.UDPSize = uint16(state.Size()) if conn.UDPSize < 512 { diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 84339d4bd..20d995710 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -119,7 +119,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg if err != nil { // Kick off health check to see if *our* upstream is broken. - if f.maxfails != 0 && err != errStopped { + if f.maxfails != 0 { proxy.Healthcheck() } @@ -185,7 +185,6 @@ var ( errNoHealthy = errors.New("no healthy proxies") errNoForward = errors.New("no forwarder defined") errCachedClosed = errors.New("cached connection was closed by peer") - errStopped = errors.New("proxy has been stopped") ) // policy tells forward what policy for selecting upstream it uses. diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index decac412c..dc03002d3 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -14,13 +14,6 @@ type persistConn struct { used time.Time } -// connErr is used to communicate the connection manager. -type connErr struct { - c *dns.Conn - err error - cached bool -} - // transport hold the persistent cache. type transport struct { conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. @@ -29,8 +22,8 @@ type transport struct { tlsConfig *tls.Config dial chan string - yield chan connErr - ret chan connErr + yield chan *dns.Conn + ret chan *dns.Conn stop chan bool } @@ -40,18 +33,11 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport { expire: defaultExpire, addr: addr, dial: make(chan string), - yield: make(chan connErr), - ret: make(chan connErr), + yield: make(chan *dns.Conn), + ret: make(chan *dns.Conn), stop: make(chan bool), } - go func() { - t.connManager() - // if connManager returns it has been stopped. - close(t.stop) - close(t.yield) - close(t.dial) - // close(t.ret) // we can still be dialing and wanting to send back the socket on t.ret - }() + go func() { t.connManager() }() return t } @@ -80,7 +66,7 @@ Wait: if time.Since(pc.used) < t.expire { // Found one, remove from pool and return this conn. t.conns[proto] = t.conns[proto][i+1:] - t.ret <- connErr{pc.c, nil, true} + t.ret <- pc.c continue Wait } // This conn has expired. Close it. @@ -91,35 +77,27 @@ Wait: t.conns[proto] = t.conns[proto][i:] SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) - go func() { - if proto != "tcp-tls" { - c, err := dns.DialTimeout(proto, t.addr, dialTimeout) - t.ret <- connErr{c, err, false} - return - } - - c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout) - t.ret <- connErr{c, err, false} - }() + t.ret <- nil case conn := <-t.yield: SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1)) // no proto here, infer from config and conn - if _, ok := conn.c.Conn.(*net.UDPConn); ok { - t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()}) + if _, ok := conn.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) continue Wait } if t.tlsConfig == nil { - t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()}) + t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) continue Wait } - t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()}) + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) case <-t.stop: + close(t.ret) return } } @@ -134,16 +112,24 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { t.dial <- proto c := <-t.ret - return c.c, c.cached, c.err + + if c != nil { + return c, true, nil + } + + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout) + return conn, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, dialTimeout) + return conn, false, err } // Yield return the connection to transport for reuse. -func (t *transport) Yield(c *dns.Conn) { - t.yield <- connErr{c, nil, false} -} +func (t *transport) Yield(c *dns.Conn) { t.yield <- c } // Stop stops the transport's connection manager. -func (t *transport) Stop() { t.stop <- true } +func (t *transport) Stop() { close(t.stop) } // SetExpire sets the connection expire time in transport. func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index c788f98cc..b6b570149 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -24,9 +24,6 @@ type Proxy struct { fails uint32 avgRtt int64 - - state uint32 - inProgress int32 } // NewProxy returns a new proxy. @@ -85,26 +82,15 @@ func (p *Proxy) Down(maxfails uint32) bool { return fails > maxfails } -// close stops the health checking goroutine and connection manager. +// close stops the health checking goroutine. func (p *Proxy) close() { - if atomic.CompareAndSwapUint32(&p.state, running, stopping) { - p.probe.Stop() - } - if atomic.LoadInt32(&p.inProgress) == 0 { - p.checkStopTransport() - } + p.probe.Stop() + p.transport.Stop() } // start starts the proxy's healthchecking. func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) } -// checkStopTransport checks if stop was requested and stops connection manager -func (p *Proxy) checkStopTransport() { - if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) { - p.transport.Stop() - } -} - const ( dialTimeout = 4 * time.Second timeout = 2 * time.Second @@ -112,9 +98,3 @@ const ( minTimeout = 10 * time.Millisecond hcDuration = 500 * time.Millisecond ) - -const ( - running = iota - stopping - stopped -) diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index acd3d240c..d473d6881 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -2,9 +2,7 @@ package forward import ( "context" - "runtime" "testing" - "time" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/test" @@ -28,50 +26,15 @@ func TestProxyClose(t *testing.T) { ctx := context.TODO() for i := 0; i < 100; i++ { - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, nil) p.start(hcDuration) - doneCnt := 0 - doneCh := make(chan bool) - timeCh := time.After(10 * time.Second) - go func() { - p.connect(ctx, state, false, false) - doneCh <- true - }() - go func() { - p.connect(ctx, state, true, false) - doneCh <- true - }() - go func() { - p.close() - doneCh <- true - }() - go func() { - p.connect(ctx, state, false, false) - doneCh <- true - }() - go func() { - p.connect(ctx, state, true, false) - doneCh <- true - }() + go func() { p.connect(ctx, state, false, false) }() + go func() { p.connect(ctx, state, true, false) }() + go func() { p.connect(ctx, state, false, false) }() + go func() { p.connect(ctx, state, true, false) }() - for doneCnt < 5 { - select { - case <-doneCh: - doneCnt++ - case <-timeCh: - t.Error("TestProxyClose is running too long, dumping goroutines:") - buf := make([]byte, 100000) - stackSize := runtime.Stack(buf, true) - t.Fatal(string(buf[:stackSize])) - } - } - if p.inProgress != 0 { - t.Errorf("unexpected query in progress") - } - if p.state != stopped { - t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state) - } + p.close() } }