diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 5967c396c..6f9897550 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -5,6 +5,7 @@ package forward import ( + "io" "strconv" "time" @@ -22,7 +23,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me proto = "tcp" } - conn, err := p.Dial(proto) + conn, cached, err := p.Dial(proto) if err != nil { return nil, err } @@ -36,6 +37,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me conn.SetWriteDeadline(time.Now().Add(timeout)) if err := conn.WriteMsg(state.Req); err != nil { conn.Close() // not giving it back + if err == io.EOF && cached { + return nil, errCachedClosed + } return nil, err } @@ -43,6 +47,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me ret, err := conn.ReadMsg() if err != nil { conn.Close() // not giving it back + if err == io.EOF && cached { + return nil, errCachedClosed + } return nil, err } diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 4c842f49e..6d06f79f2 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -7,7 +7,6 @@ package forward import ( "crypto/tls" "errors" - "io" "time" "github.com/coredns/coredns/plugin" @@ -92,11 +91,9 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg ret *dns.Msg err error ) - stop := false for { ret, err = proxy.connect(ctx, state, f.forceTCP, true) - if err != nil && err == io.EOF && !stop { // Remote side closed conn, can only happen with TCP. - stop = true + if err != nil && err == errCachedClosed { // Remote side closed conn, can only happen with TCP. continue } break @@ -176,6 +173,7 @@ var ( errInvalidDomain = errors.New("invalid domain for forward") errNoHealthy = errors.New("no healthy proxies") errNoForward = errors.New("no forwarder defined") + errCachedClosed = errors.New("cached connection was closed by peer") ) // policy tells forward what policy for selecting upstream it uses. diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index 7bf083b49..6ea4d0371 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -16,8 +16,9 @@ type persistConn struct { // connErr is used to communicate the connection manager. type connErr struct { - c *dns.Conn - err error + c *dns.Conn + err error + cached bool } // transport hold the persistent cache. @@ -86,7 +87,7 @@ Wait: if time.Since(pc.used) < t.expire { // Found one, remove from pool and return this conn. t.conns[proto] = t.conns[proto][i+1:] - t.ret <- connErr{pc.c, nil} + t.ret <- connErr{pc.c, nil, true} continue Wait } // This conn has expired. Close it. @@ -100,12 +101,12 @@ Wait: go func() { if proto != "tcp-tls" { c, err := dns.DialTimeout(proto, t.addr, dialTimeout) - t.ret <- connErr{c, err} + t.ret <- connErr{c, err, false} return } c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout) - t.ret <- connErr{c, err} + t.ret <- connErr{c, err, false} }() case conn := <-t.yield: @@ -139,7 +140,7 @@ Wait: } // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *transport) Dial(proto string) (*dns.Conn, error) { +func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { // If tls has been configured; use it. if t.tlsConfig != nil { proto = "tcp-tls" @@ -147,12 +148,12 @@ func (t *transport) Dial(proto string) (*dns.Conn, error) { t.dial <- proto c := <-t.ret - return c.c, c.err + return c.c, c.cached, c.err } // Yield return the connection to transport for reuse. func (t *transport) Yield(c *dns.Conn) { - t.yield <- connErr{c, nil} + t.yield <- connErr{c, nil, false} } // Stop stops the transport's connection manager. diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go index 5fa491a01..f4f476afa 100644 --- a/plugin/forward/persistent_test.go +++ b/plugin/forward/persistent_test.go @@ -19,9 +19,13 @@ func TestPersistent(t *testing.T) { tr := newTransport(s.Addr, nil /* no TLS */) defer tr.Stop() - c1, _ := tr.Dial("udp") - c2, _ := tr.Dial("udp") - c3, _ := tr.Dial("udp") + c1, cache1, _ := tr.Dial("udp") + c2, cache2, _ := tr.Dial("udp") + c3, cache3, _ := tr.Dial("udp") + + if cache1 || cache2 || cache3 { + t.Errorf("Expected non-cached connection") + } tr.Yield(c1) tr.Yield(c2) @@ -31,13 +35,23 @@ func TestPersistent(t *testing.T) { t.Errorf("Expected cache size to be 3, got %d", x) } - tr.Dial("udp") + c4, cache4, _ := tr.Dial("udp") if x := tr.Len(); x != 2 { t.Errorf("Expected cache size to be 2, got %d", x) } - tr.Dial("udp") + c5, cache5, _ := tr.Dial("udp") if x := tr.Len(); x != 1 { - t.Errorf("Expected cache size to be 2, got %d", x) + t.Errorf("Expected cache size to be 1, got %d", x) + } + + if cache4 == false || cache5 == false { + t.Errorf("Expected cached connection") + } + tr.Yield(c4) + tr.Yield(c5) + + if x := tr.Len(); x != 3 { + t.Errorf("Expected cache size to be 3, got %d", x) } } diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index 30bab52d1..02d3512cb 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -58,7 +58,7 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) } func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } // Dial connects to the host in p with the configured transport. -func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) } +func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) } // Yield returns the connection to the pool. func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }