plugin/forward: make Yield not block (#3336)

* plugin/forward: make Yield not block

Yield may block when we're very busy creating (and looking for)
connections. Set a small timeout on Yield so that, if the send would
block, we skip putting the connection back in the queue.
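The fix applies Go's usual pattern for a send that must not block indefinitely: wrap the channel send in a select with a time.After case. A minimal, self-contained sketch of that pattern follows; only yieldTimeout and the idea of a yield channel come from the diff below, the rest is illustrative and not the plugin's actual code.

package main

import (
    "fmt"
    "time"
)

const yieldTimeout = 25 * time.Millisecond

// tryYield hands v back on the yield channel, but gives up after
// yieldTimeout instead of blocking the caller when the receiver is busy.
func tryYield(yield chan<- int, v int) bool {
    select {
    case yield <- v:
        return true // receiver accepted the value
    case <-time.After(yieldTimeout):
        return false // receiver busy; drop the value rather than block
    }
}

func main() {
    yield := make(chan int) // unbuffered, like the transport's yield channel
    go func() { fmt.Println("received", <-yield) }()
    fmt.Println("yielded:", tryYield(yield, 1))
    time.Sleep(10 * time.Millisecond) // let the receiver print before exiting
}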

Use persistConn throughout the socket-handling code to be more
consistent.
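The persistConn type itself is not touched by this diff; from its use here (pc.c, pc.used, &persistConn{c: conn}) it is presumably a small wrapper along these lines. The field names are taken from the diff; the declaration is an inferred sketch, not copied from the source.

package forward

import (
    "time"

    "github.com/miekg/dns"
)

// persistConn couples a cached *dns.Conn with the time it was last used,
// so the connection manager can expire idle connections.
type persistConn struct {
    c    *dns.Conn
    used time.Time
}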

Signed-off-by: Miek Gieben <miek@miek.nl>

Dont do

Signed-off-by: Miek Gieben <miek@miek.nl>

* Set used in Yield

This gives one central place where we update the used timestamp on the persistConns.
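In other words, connManager no longer stamps the time when it puts a yielded connection back into its pool; Yield does it once, up front. A small self-contained sketch of that shape, using a generic slice as the pool rather than the plugin's Transport:

package main

import (
    "fmt"
    "time"
)

type persistConn struct {
    used time.Time
}

// yield stamps 'used' in one central place before the connection is handed
// back; the pool side simply appends whatever it receives.
func yield(pool *[]*persistConn, pc *persistConn) {
    pc.used = time.Now()
    *pool = append(*pool, pc)
}

func main() {
    var pool []*persistConn
    yield(&pool, &persistConn{})
    fmt.Println(len(pool), !pool[0].used.IsZero()) // 1 true
}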

Signed-off-by: Miek Gieben <miek@miek.nl>
Miek Gieben 2019-10-01 16:39:42 +01:00 committed by GitHub
parent 7b69dfebb5
commit 2d98d520b5
8 changed files with 53 additions and 105 deletions

View file

@@ -102,7 +102,6 @@ If monitoring is enabled (via the *prometheus* directive) then the following met
 * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream.
 * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy,
   and we are randomly (this always uses the `random` policy) spraying to an upstream.
-* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream.

 Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the returned RCODE
 from the upstream.

View file

@@ -44,17 +44,17 @@ func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
 }

 // Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
-func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
+func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
     // If tls has been configured; use it.
     if t.tlsConfig != nil {
         proto = "tcp-tls"
     }

     t.dial <- proto
-    c := <-t.ret
+    pc := <-t.ret

-    if c != nil {
-        return c, true, nil
+    if pc != nil {
+        return pc, true, nil
     }

     reqTime := time.Now()
@@ -62,11 +62,11 @@ func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
     if proto == "tcp-tls" {
         conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
         t.updateDialTimeout(time.Since(reqTime))
-        return conn, false, err
+        return &persistConn{c: conn}, false, err
     }
     conn, err := dns.DialTimeout(proto, t.addr, timeout)
     t.updateDialTimeout(time.Since(reqTime))
-    return conn, false, err
+    return &persistConn{c: conn}, false, err
 }

 // Connect selects an upstream, sends the request and waits for a response.
@@ -83,20 +83,20 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
         proto = state.Proto()
     }

-    conn, cached, err := p.transport.Dial(proto)
+    pc, cached, err := p.transport.Dial(proto)
     if err != nil {
         return nil, err
     }

     // Set buffer size correctly for this client.
-    conn.UDPSize = uint16(state.Size())
-    if conn.UDPSize < 512 {
-        conn.UDPSize = 512
+    pc.c.UDPSize = uint16(state.Size())
+    if pc.c.UDPSize < 512 {
+        pc.c.UDPSize = 512
     }

-    conn.SetWriteDeadline(time.Now().Add(maxTimeout))
-    if err := conn.WriteMsg(state.Req); err != nil {
-        conn.Close() // not giving it back
+    pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
+    if err := pc.c.WriteMsg(state.Req); err != nil {
+        pc.c.Close() // not giving it back
         if err == io.EOF && cached {
             return nil, ErrCachedClosed
         }
@@ -104,11 +104,11 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
     }

     var ret *dns.Msg
-    conn.SetReadDeadline(time.Now().Add(readTimeout))
+    pc.c.SetReadDeadline(time.Now().Add(readTimeout))
     for {
-        ret, err = conn.ReadMsg()
+        ret, err = pc.c.ReadMsg()
         if err != nil {
-            conn.Close() // not giving it back
+            pc.c.Close() // not giving it back
             if err == io.EOF && cached {
                 return nil, ErrCachedClosed
             }
@@ -120,7 +120,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
         }
     }

-    p.transport.Yield(conn)
+    p.transport.Yield(pc)

     rc, ok := dns.RcodeToString[ret.Rcode]
     if !ok {

View file

@@ -29,7 +29,7 @@ func TestHealth(t *testing.T) {
     p := NewProxy(s.Addr, transport.DNS)
     f := New()
     f.SetProxy(p)
-    defer f.Close()
+    defer f.OnShutdown()

     req := new(dns.Msg)
     req.SetQuestion("example.org.", dns.TypeA)
@@ -69,7 +69,7 @@ func TestHealthTimeout(t *testing.T) {
     p := NewProxy(s.Addr, transport.DNS)
     f := New()
     f.SetProxy(p)
-    defer f.Close()
+    defer f.OnShutdown()

     req := new(dns.Msg)
     req.SetQuestion("example.org.", dns.TypeA)
@@ -113,7 +113,7 @@ func TestHealthFailTwice(t *testing.T) {
     p := NewProxy(s.Addr, transport.DNS)
     f := New()
     f.SetProxy(p)
-    defer f.Close()
+    defer f.OnShutdown()

     req := new(dns.Msg)
     req.SetQuestion("example.org.", dns.TypeA)
@@ -137,7 +137,7 @@ func TestHealthMaxFails(t *testing.T) {
     f := New()
     f.maxfails = 2
     f.SetProxy(p)
-    defer f.Close()
+    defer f.OnShutdown()

     req := new(dns.Msg)
     req.SetQuestion("example.org.", dns.TypeA)
@@ -169,7 +169,7 @@ func TestHealthNoMaxFails(t *testing.T) {
     f := New()
     f.maxfails = 0
     f.SetProxy(p)
-    defer f.Close()
+    defer f.OnShutdown()

     req := new(dns.Msg)
     req.SetQuestion("example.org.", dns.TypeA)

View file

@@ -24,8 +24,8 @@ type Transport struct {
     tlsConfig *tls.Config

     dial  chan string
-    yield chan *dns.Conn
-    ret   chan *dns.Conn
+    yield chan *persistConn
+    ret   chan *persistConn
     stop  chan bool
 }
@@ -36,23 +36,13 @@ func newTransport(addr string) *Transport {
         expire: defaultExpire,
         addr:   addr,
         dial:   make(chan string),
-        yield:  make(chan *dns.Conn),
-        ret:    make(chan *dns.Conn),
+        yield:  make(chan *persistConn),
+        ret:    make(chan *persistConn),
         stop:   make(chan bool),
     }
     return t
 }

-// len returns the number of connection, used for metrics. Can only be safely
-// used inside connManager() because of data races.
-func (t *Transport) len() int {
-    l := 0
-    for _, conns := range t.conns {
-        l += len(conns)
-    }
-    return l
-}
-
 // connManagers manages the persistent connection cache for UDP and TCP.
 func (t *Transport) connManager() {
     ticker := time.NewTicker(t.expire)
@@ -66,7 +56,7 @@ Wait:
                 if time.Since(pc.used) < t.expire {
                     // Found one, remove from pool and return this conn.
                     t.conns[proto] = stack[:len(stack)-1]
-                    t.ret <- pc.c
+                    t.ret <- pc
                     continue Wait
                 }
                 // clear entire cache if the last conn is expired
@@ -75,26 +65,21 @@ Wait:
                 // transport methods anymore. So, it's safe to close them in a separate goroutine
                 go closeConns(stack)
             }
-            SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))

             t.ret <- nil

-        case conn := <-t.yield:
-            SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
+        case pc := <-t.yield:

             // no proto here, infer from config and conn
-            if _, ok := conn.Conn.(*net.UDPConn); ok {
-                t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
+            if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+                t.conns["udp"] = append(t.conns["udp"], pc)
                 continue Wait
             }

             if t.tlsConfig == nil {
-                t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
+                t.conns["tcp"] = append(t.conns["tcp"], pc)
                 continue Wait
             }

-            t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
+            t.conns["tcp-tls"] = append(t.conns["tcp-tls"], pc)

         case <-ticker.C:
             t.cleanup(false)
@@ -143,8 +128,23 @@ func (t *Transport) cleanup(all bool) {
     }
 }

+// It is hard to pin a value to this; the important thing is to not block forever. Losing a cached connection is not terrible.
+const yieldTimeout = 25 * time.Millisecond
+
 // Yield return the connection to transport for reuse.
-func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
+func (t *Transport) Yield(pc *persistConn) {
+    pc.used = time.Now() // update used time
+
+    // Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This
+    // blocks the outer go-routine and stuff will just pile up. We time out when the send would block, as returning
+    // these connections is an optimization anyway.
+    select {
+    case t.yield <- pc:
+        return
+    case <-time.After(yieldTimeout):
+        return
+    }
+}

 // Start starts the transport's connection manager.
 func (t *Transport) Start() { go t.connManager() }

View file

@@ -82,54 +82,6 @@ func TestCleanupByTimer(t *testing.T) {
     tr.Yield(c4)
 }

-func TestPartialCleanup(t *testing.T) {
-    s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
-        ret := new(dns.Msg)
-        ret.SetReply(r)
-        w.WriteMsg(ret)
-    })
-    defer s.Close()
-
-    tr := newTransport(s.Addr)
-    tr.SetExpire(100 * time.Millisecond)
-    tr.Start()
-    defer tr.Stop()
-
-    c1, _, _ := tr.Dial("udp")
-    c2, _, _ := tr.Dial("udp")
-    c3, _, _ := tr.Dial("udp")
-    c4, _, _ := tr.Dial("udp")
-    c5, _, _ := tr.Dial("udp")
-
-    tr.Yield(c1)
-    time.Sleep(10 * time.Millisecond)
-    tr.Yield(c2)
-    time.Sleep(10 * time.Millisecond)
-    tr.Yield(c3)
-    time.Sleep(50 * time.Millisecond)
-    tr.Yield(c4)
-    time.Sleep(10 * time.Millisecond)
-    tr.Yield(c5)
-    time.Sleep(40 * time.Millisecond)
-
-    c6, _, _ := tr.Dial("udp")
-    if c6 != c5 {
-        t.Errorf("Expected c6 == c5")
-    }
-    c7, _, _ := tr.Dial("udp")
-    if c7 != c4 {
-        t.Errorf("Expected c7 == c4")
-    }
-    c8, cached, _ := tr.Dial("udp")
-    if cached {
-        t.Error("Expected non-cached connection (c8)")
-    }
-
-    tr.Yield(c6)
-    tr.Yield(c7)
-    tr.Yield(c8)
-}
-
 func TestCleanupAll(t *testing.T) {
     s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
         ret := new(dns.Msg)
@@ -150,12 +102,12 @@ func TestCleanupAll(t *testing.T) {
         {c3, time.Now()},
     }

-    if tr.len() != 3 {
+    if len(tr.conns["udp"]) != 3 {
         t.Error("Expected 3 connections")
     }
     tr.cleanup(true)
-    if tr.len() > 0 {
+    if len(tr.conns["udp"]) > 0 {
         t.Error("Expected no cached connections")
     }
 }

View file

@@ -66,7 +66,7 @@ func (p *Proxy) Down(maxfails uint32) bool {
 }

 // close stops the health checking goroutine.
-func (p *Proxy) close() { p.probe.Stop() }
+func (p *Proxy) stop() { p.probe.Stop() }
 func (p *Proxy) finalizer() { p.transport.Stop() }

 // start starts the proxy's healthchecking.

View file

@@ -35,7 +35,7 @@ func TestProxyClose(t *testing.T) {
        go func() { p.Connect(ctx, state, options{}) }()
        go func() { p.Connect(ctx, state, options{forceTCP: true}) }()

-       p.close()
+       p.stop()
     }
 }

View file

@@ -54,14 +54,11 @@ func (f *Forward) OnStartup() (err error) {
 // OnShutdown stops all configured proxies.
 func (f *Forward) OnShutdown() error {
     for _, p := range f.proxies {
-        p.close()
+        p.stop()
     }
     return nil
 }

-// Close is a synonym for OnShutdown().
-func (f *Forward) Close() { f.OnShutdown() }
-
 func parseForward(c *caddy.Controller) (*Forward, error) {
     var (
         f *Forward