plugin/forward: move Dial goroutine out (#1738)

Rework TestProxyClose - close the proxy in the *same* goroutine
that started it. Close channels as long as we don't get data races
(this may need another fix).

Move the Dial goroutine out of the connManager - this simplifies things,
makes another goroutine go away, *and* removes the need for connErr
channels - the yield and ret channels can now just carry *dns.Conn.

Also:

Revert "plugin/forward: gracefull stop (#1701)"
This reverts commit 135377bf77.

Revert "rework TestProxyClose (#1735)"
This reverts commit 9e8893a0b5.
This commit is contained in:
Miek Gieben 2018-04-26 09:34:58 +01:00 committed by GitHub
parent 4c7ae4ea95
commit 270da82995
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 117 deletions

View file

@ -34,16 +34,6 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
} }
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
atomic.AddInt32(&p.inProgress, 1)
defer func() {
if atomic.AddInt32(&p.inProgress, -1) == 0 {
p.checkStopTransport()
}
}()
if atomic.LoadUint32(&p.state) != running {
return nil, errStopped
}
start := time.Now() start := time.Now()
proto := state.Proto() proto := state.Proto()
@ -55,6 +45,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Set buffer size correctly for this client. // Set buffer size correctly for this client.
conn.UDPSize = uint16(state.Size()) conn.UDPSize = uint16(state.Size())
if conn.UDPSize < 512 { if conn.UDPSize < 512 {

View file

@ -119,7 +119,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
if err != nil { if err != nil {
// Kick off health check to see if *our* upstream is broken. // Kick off health check to see if *our* upstream is broken.
if f.maxfails != 0 && err != errStopped { if f.maxfails != 0 {
proxy.Healthcheck() proxy.Healthcheck()
} }
@ -185,7 +185,6 @@ var (
errNoHealthy = errors.New("no healthy proxies") errNoHealthy = errors.New("no healthy proxies")
errNoForward = errors.New("no forwarder defined") errNoForward = errors.New("no forwarder defined")
errCachedClosed = errors.New("cached connection was closed by peer") errCachedClosed = errors.New("cached connection was closed by peer")
errStopped = errors.New("proxy has been stopped")
) )
// policy tells forward what policy for selecting upstream it uses. // policy tells forward what policy for selecting upstream it uses.

View file

@ -14,13 +14,6 @@ type persistConn struct {
used time.Time used time.Time
} }
// connErr is used to communicate the connection manager.
type connErr struct {
c *dns.Conn
err error
cached bool
}
// transport hold the persistent cache. // transport hold the persistent cache.
type transport struct { type transport struct {
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
@ -29,8 +22,8 @@ type transport struct {
tlsConfig *tls.Config tlsConfig *tls.Config
dial chan string dial chan string
yield chan connErr yield chan *dns.Conn
ret chan connErr ret chan *dns.Conn
stop chan bool stop chan bool
} }
@ -40,18 +33,11 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport {
expire: defaultExpire, expire: defaultExpire,
addr: addr, addr: addr,
dial: make(chan string), dial: make(chan string),
yield: make(chan connErr), yield: make(chan *dns.Conn),
ret: make(chan connErr), ret: make(chan *dns.Conn),
stop: make(chan bool), stop: make(chan bool),
} }
go func() { go func() { t.connManager() }()
t.connManager()
// if connManager returns it has been stopped.
close(t.stop)
close(t.yield)
close(t.dial)
// close(t.ret) // we can still be dialing and wanting to send back the socket on t.ret
}()
return t return t
} }
@ -80,7 +66,7 @@ Wait:
if time.Since(pc.used) < t.expire { if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn. // Found one, remove from pool and return this conn.
t.conns[proto] = t.conns[proto][i+1:] t.conns[proto] = t.conns[proto][i+1:]
t.ret <- connErr{pc.c, nil, true} t.ret <- pc.c
continue Wait continue Wait
} }
// This conn has expired. Close it. // This conn has expired. Close it.
@ -91,35 +77,27 @@ Wait:
t.conns[proto] = t.conns[proto][i:] t.conns[proto] = t.conns[proto][i:]
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
go func() { t.ret <- nil
if proto != "tcp-tls" {
c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
t.ret <- connErr{c, err, false}
return
}
c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
t.ret <- connErr{c, err, false}
}()
case conn := <-t.yield: case conn := <-t.yield:
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1)) SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
// no proto here, infer from config and conn // no proto here, infer from config and conn
if _, ok := conn.c.Conn.(*net.UDPConn); ok { if _, ok := conn.Conn.(*net.UDPConn); ok {
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()}) t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
continue Wait continue Wait
} }
if t.tlsConfig == nil { if t.tlsConfig == nil {
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()}) t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
continue Wait continue Wait
} }
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()}) t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
case <-t.stop: case <-t.stop:
close(t.ret)
return return
} }
} }
@ -134,16 +112,24 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
t.dial <- proto t.dial <- proto
c := <-t.ret c := <-t.ret
return c.c, c.cached, c.err
if c != nil {
return c, true, nil
}
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
return conn, false, err
}
conn, err := dns.DialTimeout(proto, t.addr, dialTimeout)
return conn, false, err
} }
// Yield return the connection to transport for reuse. // Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) { func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
t.yield <- connErr{c, nil, false}
}
// Stop stops the transport's connection manager. // Stop stops the transport's connection manager.
func (t *transport) Stop() { t.stop <- true } func (t *transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport. // SetExpire sets the connection expire time in transport.
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }

View file

@ -24,9 +24,6 @@ type Proxy struct {
fails uint32 fails uint32
avgRtt int64 avgRtt int64
state uint32
inProgress int32
} }
// NewProxy returns a new proxy. // NewProxy returns a new proxy.
@ -85,26 +82,15 @@ func (p *Proxy) Down(maxfails uint32) bool {
return fails > maxfails return fails > maxfails
} }
// close stops the health checking goroutine and connection manager. // close stops the health checking goroutine.
func (p *Proxy) close() { func (p *Proxy) close() {
if atomic.CompareAndSwapUint32(&p.state, running, stopping) { p.probe.Stop()
p.probe.Stop() p.transport.Stop()
}
if atomic.LoadInt32(&p.inProgress) == 0 {
p.checkStopTransport()
}
} }
// start starts the proxy's healthchecking. // start starts the proxy's healthchecking.
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) } func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
// checkStopTransport checks if stop was requested and stops connection manager
func (p *Proxy) checkStopTransport() {
if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) {
p.transport.Stop()
}
}
const ( const (
dialTimeout = 4 * time.Second dialTimeout = 4 * time.Second
timeout = 2 * time.Second timeout = 2 * time.Second
@ -112,9 +98,3 @@ const (
minTimeout = 10 * time.Millisecond minTimeout = 10 * time.Millisecond
hcDuration = 500 * time.Millisecond hcDuration = 500 * time.Millisecond
) )
const (
running = iota
stopping
stopped
)

View file

@ -2,9 +2,7 @@ package forward
import ( import (
"context" "context"
"runtime"
"testing" "testing"
"time"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
@ -28,50 +26,15 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, nil)
p.start(hcDuration) p.start(hcDuration)
doneCnt := 0 go func() { p.connect(ctx, state, false, false) }()
doneCh := make(chan bool) go func() { p.connect(ctx, state, true, false) }()
timeCh := time.After(10 * time.Second) go func() { p.connect(ctx, state, false, false) }()
go func() { go func() { p.connect(ctx, state, true, false) }()
p.connect(ctx, state, false, false)
doneCh <- true
}()
go func() {
p.connect(ctx, state, true, false)
doneCh <- true
}()
go func() {
p.close()
doneCh <- true
}()
go func() {
p.connect(ctx, state, false, false)
doneCh <- true
}()
go func() {
p.connect(ctx, state, true, false)
doneCh <- true
}()
for doneCnt < 5 { p.close()
select {
case <-doneCh:
doneCnt++
case <-timeCh:
t.Error("TestProxyClose is running too long, dumping goroutines:")
buf := make([]byte, 100000)
stackSize := runtime.Stack(buf, true)
t.Fatal(string(buf[:stackSize]))
}
}
if p.inProgress != 0 {
t.Errorf("unexpected query in progress")
}
if p.state != stopped {
t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
}
} }
} }