middleware/proxy: Make Unhealthy a pointer (#615)

Pointer updates are atomic so drop the sync.RWMutex as it is not needed
anymore. This also fixes the race introduced with dfc71df (although I
believe this is the first time we properly tested that code path).
This commit is contained in:
Miek Gieben 2017-04-13 16:26:05 +01:00 committed by GitHub
parent ef4fa66e67
commit acbf522ceb
6 changed files with 28 additions and 24 deletions

View file

@ -218,11 +218,11 @@ func newUpstream(hosts []string, old *staticUpstream) Upstream {
Conns: 0, Conns: 0,
Fails: 0, Fails: 0,
FailTimeout: upstream.FailTimeout, FailTimeout: upstream.FailTimeout,
Unhealthy: false, Unhealthy: newBool(),
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if *uh.Unhealthy {
return true return true
} }

View file

@ -38,10 +38,10 @@ func NewLookupWithOption(hosts []string, opts Options) Proxy {
Fails: 0, Fails: 0,
FailTimeout: upstream.FailTimeout, FailTimeout: upstream.FailTimeout,
Unhealthy: false, Unhealthy: newBool(),
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if *uh.Unhealthy {
return true return true
} }
fails := atomic.LoadInt32(&uh.Fails) fails := atomic.LoadInt32(&uh.Fails)

View file

@ -29,12 +29,15 @@ func testPool() HostPool {
pool := []*UpstreamHost{ pool := []*UpstreamHost{
{ {
Name: workableServer.URL, // this should resolve (healthcheck test) Name: workableServer.URL, // this should resolve (healthcheck test)
Unhealthy: newBool(),
}, },
{ {
Name: "http://shouldnot.resolve", // this shouldn't Name: "http://shouldnot.resolve", // this shouldn't
Unhealthy: newBool(),
}, },
{ {
Name: "http://C", Name: "http://C",
Unhealthy: newBool(),
}, },
} }
return HostPool(pool) return HostPool(pool)
@ -54,7 +57,7 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Error("Expected second round robin host to be third host in the pool.") t.Error("Expected second round robin host to be third host in the pool.")
} }
// mark host as down // mark host as down
pool[0].Unhealthy = true *pool[0].Unhealthy = true
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool)
if h != pool[1] { if h != pool[1] {
t.Error("Expected third round robin host to be first host in the pool.") t.Error("Expected third round robin host to be first host in the pool.")

View file

@ -3,7 +3,6 @@ package proxy
import ( import (
"errors" "errors"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -57,10 +56,9 @@ type UpstreamHost struct {
Name string // IP address (and port) of this upstream host Name string // IP address (and port) of this upstream host
Fails int32 Fails int32
FailTimeout time.Duration FailTimeout time.Duration
Unhealthy bool Unhealthy *bool
CheckDown UpstreamHostDownFunc CheckDown UpstreamHostDownFunc
WithoutPathPrefix string WithoutPathPrefix string
checkMu sync.Mutex
} }
// Down checks whether the upstream host is down or not. // Down checks whether the upstream host is down or not.
@ -70,7 +68,7 @@ func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil { if uh.CheckDown == nil {
// Default settings // Default settings
fails := atomic.LoadInt32(&uh.Fails) fails := atomic.LoadInt32(&uh.Fails)
return uh.Unhealthy || fails > 0 return *uh.Unhealthy || fails > 0
} }
return uh.CheckDown(uh) return uh.CheckDown(uh)
} }

View file

@ -84,11 +84,11 @@ func NewStaticUpstreams(c *caddyfile.Dispenser) ([]Upstream, error) {
Conns: 0, Conns: 0,
Fails: 0, Fails: 0,
FailTimeout: upstream.FailTimeout, FailTimeout: upstream.FailTimeout,
Unhealthy: false, Unhealthy: newBool(),
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if *uh.Unhealthy {
return true return true
} }
@ -251,22 +251,19 @@ func (u *staticUpstream) healthCheck() {
hostURL := "http://" + net.JoinHostPort(checkHostName, checkPort) + u.HealthCheck.Path hostURL := "http://" + net.JoinHostPort(checkHostName, checkPort) + u.HealthCheck.Path
host.checkMu.Lock()
defer host.checkMu.Unlock()
if r, err := http.Get(hostURL); err == nil { if r, err := http.Get(hostURL); err == nil {
io.Copy(ioutil.Discard, r.Body) io.Copy(ioutil.Discard, r.Body)
r.Body.Close() r.Body.Close()
if r.StatusCode < 200 || r.StatusCode >= 400 { if r.StatusCode < 200 || r.StatusCode >= 400 {
log.Printf("[WARNING] Health check URL %s returned HTTP code %d\n", log.Printf("[WARNING] Health check URL %s returned HTTP code %d\n",
hostURL, r.StatusCode) hostURL, r.StatusCode)
host.Unhealthy = true *host.Unhealthy = true
} else { } else {
host.Unhealthy = false *host.Unhealthy = false
} }
} else { } else {
log.Printf("[WARNING] Health check probe failed: %v\n", err) log.Printf("[WARNING] Health check probe failed: %v\n", err)
host.Unhealthy = true *host.Unhealthy = true
} }
} }
} }
@ -341,3 +338,9 @@ func (u *staticUpstream) IsAllowedDomain(name string) bool {
} }
func (u *staticUpstream) Exchanger() Exchanger { return u.ex } func (u *staticUpstream) Exchanger() Exchanger { return u.ex }
func newBool() *bool {
b := new(bool)
*b = false
return b
}

View file

@ -42,13 +42,13 @@ func TestSelect(t *testing.T) {
FailTimeout: 10 * time.Second, FailTimeout: 10 * time.Second,
MaxFails: 1, MaxFails: 1,
} }
upstream.Hosts[0].Unhealthy = true *upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true *upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true *upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil { if h := upstream.Select(); h != nil {
t.Error("Expected select to return nil as all host are down") t.Error("Expected select to return nil as all host are down")
} }
upstream.Hosts[2].Unhealthy = false *upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil { if h := upstream.Select(); h == nil {
t.Error("Expected select to not return nil") t.Error("Expected select to not return nil")
} }