plugin/forward: retry on cached tcp connection closed by peer (#1655)
* plugin/forward: retry on cached tcp connection closed by peer
* fix linter warnings
* fixed unit test
* modify error message
This commit is contained in:
parent
848a5d7c79
commit
e46ee9d9cc
5 changed files with 40 additions and 20 deletions
|
@ -5,6 +5,7 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -22,7 +23,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
|
|||
proto = "tcp"
|
||||
}
|
||||
|
||||
conn, err := p.Dial(proto)
|
||||
conn, cached, err := p.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -36,6 +37,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
|
|||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if err := conn.WriteMsg(state.Req); err != nil {
|
||||
conn.Close() // not giving it back
|
||||
if err == io.EOF && cached {
|
||||
return nil, errCachedClosed
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -43,6 +47,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
|
|||
ret, err := conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // not giving it back
|
||||
if err == io.EOF && cached {
|
||||
return nil, errCachedClosed
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ package forward
|
|||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
@ -92,11 +91,9 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
|||
ret *dns.Msg
|
||||
err error
|
||||
)
|
||||
stop := false
|
||||
for {
|
||||
ret, err = proxy.connect(ctx, state, f.forceTCP, true)
|
||||
if err != nil && err == io.EOF && !stop { // Remote side closed conn, can only happen with TCP.
|
||||
stop = true
|
||||
if err != nil && err == errCachedClosed { // Remote side closed conn, can only happen with TCP.
|
||||
continue
|
||||
}
|
||||
break
|
||||
|
@ -176,6 +173,7 @@ var (
|
|||
errInvalidDomain = errors.New("invalid domain for forward")
|
||||
errNoHealthy = errors.New("no healthy proxies")
|
||||
errNoForward = errors.New("no forwarder defined")
|
||||
errCachedClosed = errors.New("cached connection was closed by peer")
|
||||
)
|
||||
|
||||
// policy tells forward what policy for selecting upstream it uses.
|
||||
|
|
|
@ -16,8 +16,9 @@ type persistConn struct {
|
|||
|
||||
// connErr is used to communicate with the connection manager.
|
||||
type connErr struct {
|
||||
c *dns.Conn
|
||||
err error
|
||||
c *dns.Conn
|
||||
err error
|
||||
cached bool
|
||||
}
|
||||
|
||||
// transport holds the persistent cache.
|
||||
|
@ -86,7 +87,7 @@ Wait:
|
|||
if time.Since(pc.used) < t.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[proto] = t.conns[proto][i+1:]
|
||||
t.ret <- connErr{pc.c, nil}
|
||||
t.ret <- connErr{pc.c, nil, true}
|
||||
continue Wait
|
||||
}
|
||||
// This conn has expired. Close it.
|
||||
|
@ -100,12 +101,12 @@ Wait:
|
|||
go func() {
|
||||
if proto != "tcp-tls" {
|
||||
c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
|
||||
t.ret <- connErr{c, err}
|
||||
t.ret <- connErr{c, err, false}
|
||||
return
|
||||
}
|
||||
|
||||
c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
|
||||
t.ret <- connErr{c, err}
|
||||
t.ret <- connErr{c, err, false}
|
||||
}()
|
||||
|
||||
case conn := <-t.yield:
|
||||
|
@ -139,7 +140,7 @@ Wait:
|
|||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *transport) Dial(proto string) (*dns.Conn, error) {
|
||||
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
|
@ -147,12 +148,12 @@ func (t *transport) Dial(proto string) (*dns.Conn, error) {
|
|||
|
||||
t.dial <- proto
|
||||
c := <-t.ret
|
||||
return c.c, c.err
|
||||
return c.c, c.cached, c.err
|
||||
}
|
||||
|
||||
// Yield returns the connection to the transport for reuse.
|
||||
func (t *transport) Yield(c *dns.Conn) {
|
||||
t.yield <- connErr{c, nil}
|
||||
t.yield <- connErr{c, nil, false}
|
||||
}
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
|
|
|
@ -19,9 +19,13 @@ func TestPersistent(t *testing.T) {
|
|||
tr := newTransport(s.Addr, nil /* no TLS */)
|
||||
defer tr.Stop()
|
||||
|
||||
c1, _ := tr.Dial("udp")
|
||||
c2, _ := tr.Dial("udp")
|
||||
c3, _ := tr.Dial("udp")
|
||||
c1, cache1, _ := tr.Dial("udp")
|
||||
c2, cache2, _ := tr.Dial("udp")
|
||||
c3, cache3, _ := tr.Dial("udp")
|
||||
|
||||
if cache1 || cache2 || cache3 {
|
||||
t.Errorf("Expected non-cached connection")
|
||||
}
|
||||
|
||||
tr.Yield(c1)
|
||||
tr.Yield(c2)
|
||||
|
@ -31,13 +35,23 @@ func TestPersistent(t *testing.T) {
|
|||
t.Errorf("Expected cache size to be 3, got %d", x)
|
||||
}
|
||||
|
||||
tr.Dial("udp")
|
||||
c4, cache4, _ := tr.Dial("udp")
|
||||
if x := tr.Len(); x != 2 {
|
||||
t.Errorf("Expected cache size to be 2, got %d", x)
|
||||
}
|
||||
|
||||
tr.Dial("udp")
|
||||
c5, cache5, _ := tr.Dial("udp")
|
||||
if x := tr.Len(); x != 1 {
|
||||
t.Errorf("Expected cache size to be 2, got %d", x)
|
||||
t.Errorf("Expected cache size to be 1, got %d", x)
|
||||
}
|
||||
|
||||
if cache4 == false || cache5 == false {
|
||||
t.Errorf("Expected cached connection")
|
||||
}
|
||||
tr.Yield(c4)
|
||||
tr.Yield(c5)
|
||||
|
||||
if x := tr.Len(); x != 3 {
|
||||
t.Errorf("Expected cache size to be 3, got %d", x)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) }
|
|||
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
|
||||
|
||||
// Dial connects to the host in p with the configured transport.
|
||||
func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
|
||||
func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) }
|
||||
|
||||
// Yield returns the connection to the pool.
|
||||
func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
|
||||
|
|
Loading…
Add table
Reference in a new issue