plugin/forward: move Dial goroutine out (#1738)
Rework TestProxyClose: close the proxy in the *same* goroutine in which we started it. Close channels as long as we don't get data races (this may need another fix). Move the Dial goroutine out of the connManager - this simplifies things *and* makes another goroutine go away and removes the need for connErr channels - they can now just carry *dns.Conn. Also: Revert "plugin/forward: gracefull stop (#1701)". This reverts commit 135377bf77.
Revert "rework TestProxyClose (#1735)". This reverts commit 9e8893a0b5.
This commit is contained in:
parent
4c7ae4ea95
commit
270da82995
5 changed files with 36 additions and 117 deletions
|
@ -34,16 +34,6 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
||||||
atomic.AddInt32(&p.inProgress, 1)
|
|
||||||
defer func() {
|
|
||||||
if atomic.AddInt32(&p.inProgress, -1) == 0 {
|
|
||||||
p.checkStopTransport()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if atomic.LoadUint32(&p.state) != running {
|
|
||||||
return nil, errStopped
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
proto := state.Proto()
|
proto := state.Proto()
|
||||||
|
@ -55,6 +45,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set buffer size correctly for this client.
|
// Set buffer size correctly for this client.
|
||||||
conn.UDPSize = uint16(state.Size())
|
conn.UDPSize = uint16(state.Size())
|
||||||
if conn.UDPSize < 512 {
|
if conn.UDPSize < 512 {
|
||||||
|
|
|
@ -119,7 +119,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Kick off health check to see if *our* upstream is broken.
|
// Kick off health check to see if *our* upstream is broken.
|
||||||
if f.maxfails != 0 && err != errStopped {
|
if f.maxfails != 0 {
|
||||||
proxy.Healthcheck()
|
proxy.Healthcheck()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,7 +185,6 @@ var (
|
||||||
errNoHealthy = errors.New("no healthy proxies")
|
errNoHealthy = errors.New("no healthy proxies")
|
||||||
errNoForward = errors.New("no forwarder defined")
|
errNoForward = errors.New("no forwarder defined")
|
||||||
errCachedClosed = errors.New("cached connection was closed by peer")
|
errCachedClosed = errors.New("cached connection was closed by peer")
|
||||||
errStopped = errors.New("proxy has been stopped")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// policy tells forward what policy for selecting upstream it uses.
|
// policy tells forward what policy for selecting upstream it uses.
|
||||||
|
|
|
@ -14,13 +14,6 @@ type persistConn struct {
|
||||||
used time.Time
|
used time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// connErr is used to communicate the connection manager.
|
|
||||||
type connErr struct {
|
|
||||||
c *dns.Conn
|
|
||||||
err error
|
|
||||||
cached bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// transport hold the persistent cache.
|
// transport hold the persistent cache.
|
||||||
type transport struct {
|
type transport struct {
|
||||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||||
|
@ -29,8 +22,8 @@ type transport struct {
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
dial chan string
|
dial chan string
|
||||||
yield chan connErr
|
yield chan *dns.Conn
|
||||||
ret chan connErr
|
ret chan *dns.Conn
|
||||||
stop chan bool
|
stop chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,18 +33,11 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport {
|
||||||
expire: defaultExpire,
|
expire: defaultExpire,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
dial: make(chan string),
|
dial: make(chan string),
|
||||||
yield: make(chan connErr),
|
yield: make(chan *dns.Conn),
|
||||||
ret: make(chan connErr),
|
ret: make(chan *dns.Conn),
|
||||||
stop: make(chan bool),
|
stop: make(chan bool),
|
||||||
}
|
}
|
||||||
go func() {
|
go func() { t.connManager() }()
|
||||||
t.connManager()
|
|
||||||
// if connManager returns it has been stopped.
|
|
||||||
close(t.stop)
|
|
||||||
close(t.yield)
|
|
||||||
close(t.dial)
|
|
||||||
// close(t.ret) // we can still be dialing and wanting to send back the socket on t.ret
|
|
||||||
}()
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +66,7 @@ Wait:
|
||||||
if time.Since(pc.used) < t.expire {
|
if time.Since(pc.used) < t.expire {
|
||||||
// Found one, remove from pool and return this conn.
|
// Found one, remove from pool and return this conn.
|
||||||
t.conns[proto] = t.conns[proto][i+1:]
|
t.conns[proto] = t.conns[proto][i+1:]
|
||||||
t.ret <- connErr{pc.c, nil, true}
|
t.ret <- pc.c
|
||||||
continue Wait
|
continue Wait
|
||||||
}
|
}
|
||||||
// This conn has expired. Close it.
|
// This conn has expired. Close it.
|
||||||
|
@ -91,35 +77,27 @@ Wait:
|
||||||
t.conns[proto] = t.conns[proto][i:]
|
t.conns[proto] = t.conns[proto][i:]
|
||||||
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
|
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
|
||||||
|
|
||||||
go func() {
|
t.ret <- nil
|
||||||
if proto != "tcp-tls" {
|
|
||||||
c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
|
|
||||||
t.ret <- connErr{c, err, false}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
|
|
||||||
t.ret <- connErr{c, err, false}
|
|
||||||
}()
|
|
||||||
|
|
||||||
case conn := <-t.yield:
|
case conn := <-t.yield:
|
||||||
|
|
||||||
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
|
SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
|
||||||
|
|
||||||
// no proto here, infer from config and conn
|
// no proto here, infer from config and conn
|
||||||
if _, ok := conn.c.Conn.(*net.UDPConn); ok {
|
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||||
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
|
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||||
continue Wait
|
continue Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.tlsConfig == nil {
|
if t.tlsConfig == nil {
|
||||||
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
|
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||||
continue Wait
|
continue Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
|
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||||
|
|
||||||
case <-t.stop:
|
case <-t.stop:
|
||||||
|
close(t.ret)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -134,16 +112,24 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
|
||||||
|
|
||||||
t.dial <- proto
|
t.dial <- proto
|
||||||
c := <-t.ret
|
c := <-t.ret
|
||||||
return c.c, c.cached, c.err
|
|
||||||
|
if c != nil {
|
||||||
|
return c, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if proto == "tcp-tls" {
|
||||||
|
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
|
||||||
|
return conn, false, err
|
||||||
|
}
|
||||||
|
conn, err := dns.DialTimeout(proto, t.addr, dialTimeout)
|
||||||
|
return conn, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yield return the connection to transport for reuse.
|
// Yield return the connection to transport for reuse.
|
||||||
func (t *transport) Yield(c *dns.Conn) {
|
func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||||
t.yield <- connErr{c, nil, false}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the transport's connection manager.
|
// Stop stops the transport's connection manager.
|
||||||
func (t *transport) Stop() { t.stop <- true }
|
func (t *transport) Stop() { close(t.stop) }
|
||||||
|
|
||||||
// SetExpire sets the connection expire time in transport.
|
// SetExpire sets the connection expire time in transport.
|
||||||
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
|
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||||
|
|
|
@ -24,9 +24,6 @@ type Proxy struct {
|
||||||
fails uint32
|
fails uint32
|
||||||
|
|
||||||
avgRtt int64
|
avgRtt int64
|
||||||
|
|
||||||
state uint32
|
|
||||||
inProgress int32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProxy returns a new proxy.
|
// NewProxy returns a new proxy.
|
||||||
|
@ -85,26 +82,15 @@ func (p *Proxy) Down(maxfails uint32) bool {
|
||||||
return fails > maxfails
|
return fails > maxfails
|
||||||
}
|
}
|
||||||
|
|
||||||
// close stops the health checking goroutine and connection manager.
|
// close stops the health checking goroutine.
|
||||||
func (p *Proxy) close() {
|
func (p *Proxy) close() {
|
||||||
if atomic.CompareAndSwapUint32(&p.state, running, stopping) {
|
p.probe.Stop()
|
||||||
p.probe.Stop()
|
p.transport.Stop()
|
||||||
}
|
|
||||||
if atomic.LoadInt32(&p.inProgress) == 0 {
|
|
||||||
p.checkStopTransport()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// start starts the proxy's healthchecking.
|
// start starts the proxy's healthchecking.
|
||||||
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
|
func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
|
||||||
|
|
||||||
// checkStopTransport checks if stop was requested and stops connection manager
|
|
||||||
func (p *Proxy) checkStopTransport() {
|
|
||||||
if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) {
|
|
||||||
p.transport.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 4 * time.Second
|
dialTimeout = 4 * time.Second
|
||||||
timeout = 2 * time.Second
|
timeout = 2 * time.Second
|
||||||
|
@ -112,9 +98,3 @@ const (
|
||||||
minTimeout = 10 * time.Millisecond
|
minTimeout = 10 * time.Millisecond
|
||||||
hcDuration = 500 * time.Millisecond
|
hcDuration = 500 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
running = iota
|
|
||||||
stopping
|
|
||||||
stopped
|
|
||||||
)
|
|
||||||
|
|
|
@ -2,9 +2,7 @@ package forward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||||
"github.com/coredns/coredns/plugin/test"
|
"github.com/coredns/coredns/plugin/test"
|
||||||
|
@ -28,50 +26,15 @@ func TestProxyClose(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
p := NewProxy(s.Addr, nil /* no TLS */)
|
p := NewProxy(s.Addr, nil)
|
||||||
p.start(hcDuration)
|
p.start(hcDuration)
|
||||||
|
|
||||||
doneCnt := 0
|
go func() { p.connect(ctx, state, false, false) }()
|
||||||
doneCh := make(chan bool)
|
go func() { p.connect(ctx, state, true, false) }()
|
||||||
timeCh := time.After(10 * time.Second)
|
go func() { p.connect(ctx, state, false, false) }()
|
||||||
go func() {
|
go func() { p.connect(ctx, state, true, false) }()
|
||||||
p.connect(ctx, state, false, false)
|
|
||||||
doneCh <- true
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
p.connect(ctx, state, true, false)
|
|
||||||
doneCh <- true
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
p.close()
|
|
||||||
doneCh <- true
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
p.connect(ctx, state, false, false)
|
|
||||||
doneCh <- true
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
p.connect(ctx, state, true, false)
|
|
||||||
doneCh <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
for doneCnt < 5 {
|
p.close()
|
||||||
select {
|
|
||||||
case <-doneCh:
|
|
||||||
doneCnt++
|
|
||||||
case <-timeCh:
|
|
||||||
t.Error("TestProxyClose is running too long, dumping goroutines:")
|
|
||||||
buf := make([]byte, 100000)
|
|
||||||
stackSize := runtime.Stack(buf, true)
|
|
||||||
t.Fatal(string(buf[:stackSize]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.inProgress != 0 {
|
|
||||||
t.Errorf("unexpected query in progress")
|
|
||||||
}
|
|
||||||
if p.state != stopped {
|
|
||||||
t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue