diff --git a/plugin/forward/README.md b/plugin/forward/README.md index f9c2f482b..cc1845377 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -50,6 +50,7 @@ forward FROM TO... { tls_servername NAME policy random|round_robin|sequential health_check DURATION + max_queries MAX } ~~~ @@ -83,6 +84,11 @@ forward FROM TO... { * `round_robin` is a policy that selects hosts based on round robin ordering. * `sequential` is a policy that selects hosts based on sequential ordering. * `health_check`, use a different **DURATION** for health checking, the default duration is 0.5s. +* `max_concurrent` **MAX** will limit the number of concurrent queries to **MAX**. Any new query that would + raise the number of concurrent queries above the **MAX** will result in a SERVFAIL response. This + response does not count as a health failure. When choosing a value for **MAX**, pick a number + at least greater than the expected *upstream query rate* * *latency* of the upstream servers. + As an upper bound for **MAX**, consider that each concurrent query will use about 2kb of memory. Also note the TLS config is "global" for the whole forwarding proxy if you need a different `tls-name` for different upstreams you're out of luck. @@ -102,7 +108,8 @@ If monitoring is enabled (via the *prometheus* plugin) then the following metric * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream. * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy, and we are randomly (this always uses the `random` policy) spraying to an upstream. - +* `max_concurrent_reject_count_total{}` - counter of the number of queries rejected because the + number of concurrent queries were at maximum. Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the returned RCODE from the upstream. diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 6631b7bab..f6dd939e9 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "errors" + "sync/atomic" "time" "github.com/coredns/coredns/plugin" @@ -25,6 +26,8 @@ var log = clog.NewWithPlugin("forward") // Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list // of proxies each representing one upstream proxy. type Forward struct { + concurrent int64 // atomic counters need to be first in struct for proper alignment + proxies []*Proxy p policy.Policy hcInterval time.Duration @@ -36,9 +39,14 @@ type Forward struct { tlsServerName string maxfails uint32 expire time.Duration + maxConcurrent int64 opts options // also here for testing + // ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded + // the maximum allowed (maxConcurrent) + ErrLimitExceeded error + Next plugin.Handler } @@ -68,6 +76,15 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r) } + if f.maxConcurrent > 0 { + count := atomic.AddInt64(&(f.concurrent), 1) + defer atomic.AddInt64(&(f.concurrent), -1) + if count > f.maxConcurrent { + MaxConcurrentRejectCount.Add(1) + return dns.RcodeServerFailure, f.ErrLimitExceeded + } + } + fails := 0 var span, child ot.Span var upstreamErr error diff --git a/plugin/forward/metrics.go b/plugin/forward/metrics.go index e120f55fc..d92028d24 100644 --- a/plugin/forward/metrics.go +++ b/plugin/forward/metrics.go @@ -45,4 +45,10 @@ var ( Name: "sockets_open", Help: "Gauge of open sockets per upstream.", }, []string{"to"}) + MaxConcurrentRejectCount = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "max_concurrent_reject_count_total", + Help: "Counter of the number of queries rejected because the concurrent queries were at maximum.", + }) ) diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index fa35639f2..dadf535c7 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -1,6 +1,7 @@ package forward import ( + "errors" "fmt" "strconv" "time" @@ -121,6 +122,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { } f.proxies[i].SetExpire(f.expire) } + return f, nil } @@ -211,6 +213,19 @@ func parseBlock(c *caddy.Controller, f *Forward) error { default: return c.Errf("unknown policy '%s'", x) } + case "max_concurrent": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + if n < 0 { + return fmt.Errorf("max_concurrent can't be negative: %d", n) + } + f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val()) + f.maxConcurrent = int64(n) default: return c.Errf("unknown property '%s'", c.Val()) diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index ae0c991d9..c2c2f6759 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -168,7 +168,45 @@ nameserver 10.10.255.253`), 0666); err != nil { } } for _, p := range f.proxies { - p.health.Check(p) // this should almost always err, we don't care it shoulnd't crash + p.health.Check(p) // this should almost always err, we don't care it shouldn't crash + } + } +} + +func TestSetupMaxConcurrent(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedVal int64 + expectedErr string + }{ + // positive + {"forward . 127.0.0.1 {\nmax_concurrent 1000\n}\n", false, 1000, ""}, + // negative + {"forward . 127.0.0.1 {\nmax_concurrent many\n}\n", true, 0, "invalid"}, + {"forward . 127.0.0.1 {\nmax_concurrent -4\n}\n", true, 0, "negative"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + f, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && f.maxConcurrent != test.expectedVal { + t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, f.maxConcurrent) } } }