plugin/cache: Fix cache poisoning exploit (#5174)
This commit is contained in:
parent
5a4437bb23
commit
c4bc1a5471
3 changed files with 29 additions and 25 deletions
12
plugin/cache/cache_test.go
vendored
12
plugin/cache/cache_test.go
vendored
|
@ -191,7 +191,7 @@ func TestCache(t *testing.T) {
|
||||||
|
|
||||||
c, crr := newTestCache(maxTTL)
|
c, crr := newTestCache(maxTTL)
|
||||||
|
|
||||||
for _, tc := range cacheTestCases {
|
for n, tc := range cacheTestCases {
|
||||||
m := tc.in.Msg()
|
m := tc.in.Msg()
|
||||||
m = cacheMsg(m, tc)
|
m = cacheMsg(m, tc)
|
||||||
|
|
||||||
|
@ -204,11 +204,15 @@ func TestCache(t *testing.T) {
|
||||||
crr.set(m, k, mt, c.pttl)
|
crr.set(m, k, mt, c.pttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
i, _ := c.get(time.Now().UTC(), state, "dns://:53")
|
i := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53")
|
||||||
ok := i != nil
|
ok := i != nil
|
||||||
|
|
||||||
if ok != tc.shouldCache {
|
if !tc.shouldCache && ok {
|
||||||
t.Errorf("Cached message that should not have been cached: %s", state.Name())
|
t.Errorf("Test %d: Cached message that should not have been cached: %s", n, state.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tc.shouldCache && !ok {
|
||||||
|
t.Errorf("Test %d: Did not cache message that should have been cached: %s", n, state.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
27
plugin/cache/handler.go
vendored
27
plugin/cache/handler.go
vendored
|
@ -89,38 +89,23 @@ func (c *Cache) shouldPrefetch(i *item, now time.Time) bool {
|
||||||
// Name implements the Handler interface.
|
// Name implements the Handler interface.
|
||||||
func (c *Cache) Name() string { return "cache" }
|
func (c *Cache) Name() string { return "cache" }
|
||||||
|
|
||||||
func (c *Cache) get(now time.Time, state request.Request, server string) (*item, bool) {
|
|
||||||
k := hash(state.Name(), state.QType())
|
|
||||||
cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc()
|
|
||||||
|
|
||||||
if i, ok := c.ncache.Get(k); ok && i.(*item).ttl(now) > 0 {
|
|
||||||
cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc()
|
|
||||||
return i.(*item), true
|
|
||||||
}
|
|
||||||
|
|
||||||
if i, ok := c.pcache.Get(k); ok && i.(*item).ttl(now) > 0 {
|
|
||||||
cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc()
|
|
||||||
return i.(*item), true
|
|
||||||
}
|
|
||||||
cacheMisses.WithLabelValues(server, c.zonesMetricLabel).Inc()
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// getIgnoreTTL unconditionally returns an item if it exists in the cache.
|
// getIgnoreTTL unconditionally returns an item if it exists in the cache.
|
||||||
func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item {
|
func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item {
|
||||||
k := hash(state.Name(), state.QType())
|
k := hash(state.Name(), state.QType())
|
||||||
cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc()
|
cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc()
|
||||||
|
|
||||||
if i, ok := c.ncache.Get(k); ok {
|
if i, ok := c.ncache.Get(k); ok {
|
||||||
ttl := i.(*item).ttl(now)
|
itm := i.(*item)
|
||||||
if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) {
|
ttl := itm.ttl(now)
|
||||||
|
if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
|
||||||
cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc()
|
cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc()
|
||||||
return i.(*item)
|
return i.(*item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if i, ok := c.pcache.Get(k); ok {
|
if i, ok := c.pcache.Get(k); ok {
|
||||||
ttl := i.(*item).ttl(now)
|
itm := i.(*item)
|
||||||
if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) {
|
ttl := itm.ttl(now)
|
||||||
|
if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
|
||||||
cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc()
|
cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc()
|
||||||
return i.(*item)
|
return i.(*item)
|
||||||
}
|
}
|
||||||
|
|
15
plugin/cache/item.go
vendored
15
plugin/cache/item.go
vendored
|
@ -1,14 +1,18 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coredns/coredns/plugin/cache/freq"
|
"github.com/coredns/coredns/plugin/cache/freq"
|
||||||
|
"github.com/coredns/coredns/request"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type item struct {
|
type item struct {
|
||||||
|
Name string
|
||||||
|
QType uint16
|
||||||
Rcode int
|
Rcode int
|
||||||
AuthenticatedData bool
|
AuthenticatedData bool
|
||||||
RecursionAvailable bool
|
RecursionAvailable bool
|
||||||
|
@ -24,6 +28,10 @@ type item struct {
|
||||||
|
|
||||||
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
|
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
|
||||||
i := new(item)
|
i := new(item)
|
||||||
|
if len(m.Question) != 0 {
|
||||||
|
i.Name = m.Question[0].Name
|
||||||
|
i.QType = m.Question[0].Qtype
|
||||||
|
}
|
||||||
i.Rcode = m.Rcode
|
i.Rcode = m.Rcode
|
||||||
i.AuthenticatedData = m.AuthenticatedData
|
i.AuthenticatedData = m.AuthenticatedData
|
||||||
i.RecursionAvailable = m.RecursionAvailable
|
i.RecursionAvailable = m.RecursionAvailable
|
||||||
|
@ -87,3 +95,10 @@ func (i *item) ttl(now time.Time) int {
|
||||||
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
|
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
|
||||||
return ttl
|
return ttl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *item) matches(state request.Request) bool {
|
||||||
|
if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue