diff --git a/plugin/forward/README.md b/plugin/forward/README.md index ae7ce67ba..fd00253dd 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -102,7 +102,6 @@ If monitoring is enabled (via the *prometheus* directive) then the following met * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream. * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy, and we are randomly (this always uses the `random` policy) spraying to an upstream. -* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream. Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the returned RCODE from the upstream. diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 8fde2224b..9ac1afe16 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -44,17 +44,17 @@ func (t *Transport) updateDialTimeout(newDialTime time.Duration) { } // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { +func (t *Transport) Dial(proto string) (*persistConn, bool, error) { // If tls has been configured; use it. 
if t.tlsConfig != nil { proto = "tcp-tls" } t.dial <- proto - c := <-t.ret + pc := <-t.ret - if c != nil { - return c, true, nil + if pc != nil { + return pc, true, nil } reqTime := time.Now() @@ -62,11 +62,11 @@ func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { if proto == "tcp-tls" { conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) t.updateDialTimeout(time.Since(reqTime)) - return conn, false, err + return &persistConn{c: conn}, false, err } conn, err := dns.DialTimeout(proto, t.addr, timeout) t.updateDialTimeout(time.Since(reqTime)) - return conn, false, err + return &persistConn{c: conn}, false, err } // Connect selects an upstream, sends the request and waits for a response. @@ -83,20 +83,20 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options proto = state.Proto() } - conn, cached, err := p.transport.Dial(proto) + pc, cached, err := p.transport.Dial(proto) if err != nil { return nil, err } // Set buffer size correctly for this client. 
- conn.UDPSize = uint16(state.Size()) - if conn.UDPSize < 512 { - conn.UDPSize = 512 + pc.c.UDPSize = uint16(state.Size()) + if pc.c.UDPSize < 512 { + pc.c.UDPSize = 512 } - conn.SetWriteDeadline(time.Now().Add(maxTimeout)) - if err := conn.WriteMsg(state.Req); err != nil { - conn.Close() // not giving it back + pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) + if err := pc.c.WriteMsg(state.Req); err != nil { + pc.c.Close() // not giving it back if err == io.EOF && cached { return nil, ErrCachedClosed } @@ -104,11 +104,11 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options } var ret *dns.Msg - conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.c.SetReadDeadline(time.Now().Add(readTimeout)) for { - ret, err = conn.ReadMsg() + ret, err = pc.c.ReadMsg() if err != nil { - conn.Close() // not giving it back + pc.c.Close() // not giving it back if err == io.EOF && cached { return nil, ErrCachedClosed } @@ -120,7 +120,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options } } - p.transport.Yield(conn) + p.transport.Yield(pc) rc, ok := dns.RcodeToString[ret.Rcode] if !ok { diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 3d06f2835..b1785eeae 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -29,7 +29,7 @@ func TestHealth(t *testing.T) { p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) - defer f.Close() + defer f.OnShutdown() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) @@ -69,7 +69,7 @@ func TestHealthTimeout(t *testing.T) { p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) - defer f.Close() + defer f.OnShutdown() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) @@ -113,7 +113,7 @@ func TestHealthFailTwice(t *testing.T) { p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) - defer f.Close() + defer f.OnShutdown() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) 
@@ -137,7 +137,7 @@ func TestHealthMaxFails(t *testing.T) { f := New() f.maxfails = 2 f.SetProxy(p) - defer f.Close() + defer f.OnShutdown() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) @@ -169,7 +169,7 @@ func TestHealthNoMaxFails(t *testing.T) { f := New() f.maxfails = 0 f.SetProxy(p) - defer f.Close() + defer f.OnShutdown() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index d1348f94d..17cc5680a 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -24,8 +24,8 @@ type Transport struct { tlsConfig *tls.Config dial chan string - yield chan *dns.Conn - ret chan *dns.Conn + yield chan *persistConn + ret chan *persistConn stop chan bool } @@ -36,23 +36,13 @@ func newTransport(addr string) *Transport { expire: defaultExpire, addr: addr, dial: make(chan string), - yield: make(chan *dns.Conn), - ret: make(chan *dns.Conn), + yield: make(chan *persistConn), + ret: make(chan *persistConn), stop: make(chan bool), } return t } -// len returns the number of connection, used for metrics. Can only be safely -// used inside connManager() because of data races. -func (t *Transport) len() int { - l := 0 - for _, conns := range t.conns { - l += len(conns) - } - return l -} - // connManagers manages the persistent connection cache for UDP and TCP. func (t *Transport) connManager() { ticker := time.NewTicker(t.expire) @@ -66,7 +56,7 @@ Wait: if time.Since(pc.used) < t.expire { // Found one, remove from pool and return this conn. t.conns[proto] = stack[:len(stack)-1] - t.ret <- pc.c + t.ret <- pc continue Wait } // clear entire cache if the last conn is expired @@ -75,26 +65,21 @@ Wait: // transport methods anymore. 
So, it's safe to close them in a separate goroutine go closeConns(stack) } - SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) - t.ret <- nil - case conn := <-t.yield: - - SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1)) - + case pc := <-t.yield: // no proto here, infer from config and conn - if _, ok := conn.Conn.(*net.UDPConn); ok { - t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) + if _, ok := pc.c.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], pc) continue Wait } if t.tlsConfig == nil { - t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) + t.conns["tcp"] = append(t.conns["tcp"], pc) continue Wait } - t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], pc) case <-ticker.C: t.cleanup(false) @@ -143,8 +128,23 @@ func (t *Transport) cleanup(all bool) { } } +// It is hard to pin a value to this; the important thing is not to block forever, losing a cached connection is not terrible. +const yieldTimeout = 25 * time.Millisecond + // Yield return the connection to transport for reuse. -func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } +func (t *Transport) Yield(pc *persistConn) { + pc.used = time.Now() // update used time + + // Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This + // blocks the outer go-routine and stuff will just pile up. We time out when the send fails, as returning + // these connections is an optimization anyway. + select { + case t.yield <- pc: + return + case <-time.After(yieldTimeout): + return + } +} // Start starts the transport's connection manager. 
func (t *Transport) Start() { go t.connManager() } diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go index f1c906076..1fb239ca7 100644 --- a/plugin/forward/persistent_test.go +++ b/plugin/forward/persistent_test.go @@ -82,54 +82,6 @@ func TestCleanupByTimer(t *testing.T) { tr.Yield(c4) } -func TestPartialCleanup(t *testing.T) { - s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { - ret := new(dns.Msg) - ret.SetReply(r) - w.WriteMsg(ret) - }) - defer s.Close() - - tr := newTransport(s.Addr) - tr.SetExpire(100 * time.Millisecond) - tr.Start() - defer tr.Stop() - - c1, _, _ := tr.Dial("udp") - c2, _, _ := tr.Dial("udp") - c3, _, _ := tr.Dial("udp") - c4, _, _ := tr.Dial("udp") - c5, _, _ := tr.Dial("udp") - - tr.Yield(c1) - time.Sleep(10 * time.Millisecond) - tr.Yield(c2) - time.Sleep(10 * time.Millisecond) - tr.Yield(c3) - time.Sleep(50 * time.Millisecond) - tr.Yield(c4) - time.Sleep(10 * time.Millisecond) - tr.Yield(c5) - time.Sleep(40 * time.Millisecond) - - c6, _, _ := tr.Dial("udp") - if c6 != c5 { - t.Errorf("Expected c6 == c5") - } - c7, _, _ := tr.Dial("udp") - if c7 != c4 { - t.Errorf("Expected c7 == c4") - } - c8, cached, _ := tr.Dial("udp") - if cached { - t.Error("Expected non-cached connection (c8)") - } - - tr.Yield(c6) - tr.Yield(c7) - tr.Yield(c8) -} - func TestCleanupAll(t *testing.T) { s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) @@ -150,12 +102,12 @@ func TestCleanupAll(t *testing.T) { {c3, time.Now()}, } - if tr.len() != 3 { + if len(tr.conns["udp"]) != 3 { t.Error("Expected 3 connections") } tr.cleanup(true) - if tr.len() > 0 { + if len(tr.conns["udp"]) > 0 { t.Error("Expected no cached connections") } } diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index 6485d8d72..60dfa47a6 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -66,7 +66,7 @@ func (p *Proxy) Down(maxfails uint32) bool { } -// close stops the health checking 
goroutine. +// stop stops the health checking goroutine. -func (p *Proxy) close() { p.probe.Stop() } +func (p *Proxy) stop() { p.probe.Stop() } func (p *Proxy) finalizer() { p.transport.Stop() } // start starts the proxy's healthchecking. diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index 7075e1133..b962f561b 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -35,7 +35,7 @@ func TestProxyClose(t *testing.T) { go func() { p.Connect(ctx, state, options{}) }() go func() { p.Connect(ctx, state, options{forceTCP: true}) }() - p.close() + p.stop() } } diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 3bfee9f01..d45756693 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -54,14 +54,11 @@ func (f *Forward) OnStartup() (err error) { // OnShutdown stops all configured proxies. func (f *Forward) OnShutdown() error { for _, p := range f.proxies { - p.close() + p.stop() } return nil } -// Close is a synonym for OnShutdown(). -// Close is a synonym for OnShutdown(). -func (f *Forward) Close() { f.OnShutdown() } - func parseForward(c *caddy.Controller) (*Forward, error) { var ( f *Forward