coredns/plugin/pkg/watch/watcher.go
2018-08-14 08:55:55 -07:00

178 lines
4.1 KiB
Go

package watch
import (
"fmt"
"io"
"sync"
"github.com/miekg/dns"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
)
// Watcher handles watch creation, cancellation, and processing.
// Manager, created by NewWatcher, is the implementation in this file.
type Watcher interface {
// Watch monitors a client stream and creates and cancels watches.
// It serves a single stream until that stream ends; the Manager
// implementation returns nil when the client closes the stream (io.EOF)
// and the receive error otherwise.
Watch(pb.DnsService_WatchServer) error
// Stop cancels open watches and stops the watch processing go routine.
Stop()
}
// Manager contains all the data needed to manage watches. NewWatcher returns
// a *Manager, whose Watch and Stop methods satisfy the Watcher interface.
type Manager struct {
// changes is the channel handed to every plugin through SetWatchChan; the
// process goroutine receives changed names from it.
changes Chan
// stopper tells the process goroutine to exit; Stop sends on it.
stopper chan bool
// counter is the most recently issued watch ID; nextID increments it.
counter int64
// watches maps a normalized qname to the watches registered for that name.
watches map[string]watchlist
// plugins are told when a qname starts being watched (Watch) and when its
// last watch has been cancelled (StopWatching).
plugins []Watchable
// mutex guards counter and watches.
mutex sync.Mutex
}
// watchlist maps a watch ID to the client stream that created that watch.
type watchlist map[int64]pb.DnsService_WatchServer
// NewWatcher creates a Watcher, which is used to manage watched names.
//
// Every plugin in plugins is given the manager's change channel through
// SetWatchChan, and the goroutine that processes changes is started before
// the manager is returned. Call Stop to end that goroutine.
func NewWatcher(plugins []Watchable) *Manager {
	mgr := &Manager{
		changes: make(Chan),
		stopper: make(chan bool),
		watches: make(map[string]watchlist),
		plugins: plugins,
	}

	// All plugins report changes on the same channel.
	for i := range plugins {
		plugins[i].SetWatchChan(mgr.changes)
	}

	go mgr.process()
	return mgr
}
// nextID hands out the next watch ID. The counter starts at its zero value,
// so the first ID is 1 and each call returns one more than the previous one.
// The counter is only read and written while the manager's mutex is held.
func (w *Manager) nextID() int64 {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	w.counter++
	return w.counter
}
// Watch monitors a client stream and creates and cancels watches.
//
// It reads requests from stream until the client closes it (io.EOF, reported
// as a nil error) or Recv fails (that error is returned). A create request
// registers a new watch for the normalized qname of the enclosed DNS query; a
// cancel request removes a watch that was created on this same stream.
// Messages that are neither are ignored.
//
// Before Watch returns, every watch still registered for stream is removed,
// so a client that disconnects does not leave entries behind in the manager
// or keep the plugins watching names nobody is subscribed to.
func (w *Manager) Watch(stream pb.DnsService_WatchServer) error {
	// Nothing can be sent on the stream after this handler returns, so the
	// stream's watches must not outlive the call.
	defer w.dropStream(stream)

	for {
		in, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		if create := in.GetCreateRequest(); create != nil {
			// Query is nil when the client left it out. Hand an empty packet
			// to the decoder instead of dereferencing it, so a malformed
			// request gets the normal "could not decode" answer.
			var packed []byte
			if create.Query != nil {
				packed = create.Query.Msg
			}
			w.createWatch(stream, packed)
			continue
		}

		if cancel := in.GetCancelRequest(); cancel != nil {
			w.cancelWatch(stream, cancel.WatchId)
		}
	}
}

// createWatch decodes a packed DNS query, confirms a new watch ID to the
// client, registers stream under the query's qname and asks every plugin to
// start watching that name.
func (w *Manager) createWatch(stream pb.DnsService_WatchServer, packed []byte) {
	msg := new(dns.Msg)
	if err := msg.Unpack(packed); err != nil {
		log.Warningf("Could not decode watch request: %s\n", err)
		sendOrWarn(stream, &pb.WatchResponse{Err: "could not decode request"})
		return
	}

	id := w.nextID()
	if err := stream.Send(&pb.WatchResponse{WatchId: id, Created: true}); err != nil {
		// if we fail to notify client of watch creation, don't create the watch
		return
	}

	// Normalize qname
	qname := (&request.Request{Req: msg}).Name()

	w.mutex.Lock()
	wl, ok := w.watches[qname]
	if !ok {
		wl = make(watchlist)
		w.watches[qname] = wl
	}
	wl[id] = stream
	w.mutex.Unlock()

	// A plugin that cannot watch the name is reported to the client, but the
	// watch stays registered for the remaining plugins.
	for _, p := range w.plugins {
		if err := p.Watch(qname); err != nil {
			log.Warningf("Failed to start watch for %s in plugin %s: %s\n", qname, p.Name(), err)
			sendOrWarn(stream, &pb.WatchResponse{Err: fmt.Sprintf("failed to start watch for %s in plugin %s", qname, p.Name())})
		}
	}
}

// cancelWatch removes the watch with the given id, provided it was created on
// stream, and confirms the cancellation to the client. Unknown IDs and IDs
// owned by another stream are ignored without a response.
func (w *Manager) cancelWatch(stream pb.DnsService_WatchServer, id int64) {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	for qname, wl := range w.watches {
		owner, ok := wl[id]
		if !ok {
			continue
		}
		// IDs come from a single counter, so the ID lives under exactly one
		// qname and the search can end here either way.

		// only allow cancels from the client that started it
		// TODO: test what happens if a stream tries to cancel a watchID that it doesn't own
		if owner != stream {
			return
		}

		delete(wl, id)
		// if there are no more watches for this qname, we should tell the plugins
		w.releaseIfEmpty(qname)
		// let the client know we canceled the watch
		sendOrWarn(stream, &pb.WatchResponse{WatchId: id, Canceled: true})
		return
	}
}

// dropStream removes every watch that was created on stream.
func (w *Manager) dropStream(stream pb.DnsService_WatchServer) {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	for qname, wl := range w.watches {
		for id, owner := range wl {
			if owner == stream {
				delete(wl, id)
			}
		}
		w.releaseIfEmpty(qname)
	}
}

// releaseIfEmpty tells the plugins to stop watching qname and forgets the name
// once no watch is left for it. The caller must hold w.mutex.
func (w *Manager) releaseIfEmpty(qname string) {
	if len(w.watches[qname]) != 0 {
		return
	}
	for _, p := range w.plugins {
		p.StopWatching(qname)
	}
	delete(w.watches, qname)
}

// sendOrWarn sends a best-effort response on stream. A failed send is logged
// rather than acted on; Watch keeps reading until Recv itself reports that
// the stream is finished.
func sendOrWarn(stream pb.DnsService_WatchServer, resp *pb.WatchResponse) {
	if err := stream.Send(resp); err != nil {
		log.Warningf("Could not send watch response: %s\n", err)
	}
}
// process is the manager's background loop, started by NewWatcher. It runs
// until a value arrives on w.stopper.
//
// Each name received on w.changes is matched against the watched qnames, and
// every watch on a matching qname is sent a WatchResponse carrying its ID and
// the qname. A watch whose stream can no longer be written to is removed.
// When that leaves a qname without any watch, the plugins are told to stop
// watching it and the name is forgotten, exactly as an explicit cancel of the
// last watch would do.
func (w *Manager) process() {
	for {
		select {
		case <-w.stopper:
			return
		case changed := <-w.changes:
			// The zone list depends only on the changed name, so build it once.
			zones := plugin.Zones([]string{changed})

			w.mutex.Lock()
			for qname, wl := range w.watches {
				if zones.Matches(qname) == "" {
					continue
				}
				for id, stream := range wl {
					if err := stream.Send(&pb.WatchResponse{WatchId: id, Qname: qname}); err != nil {
						log.Warningf("Error sending change for %s to watch %d: %s. Removing watch.\n", qname, id, err)
						delete(wl, id)
					}
				}
				// if there are no more watches for this qname, we should tell the plugins
				if len(wl) == 0 {
					for _, p := range w.plugins {
						p.StopWatching(qname)
					}
					delete(w.watches, qname)
				}
			}
			w.mutex.Unlock()
		}
	}
}
// Stop cancels open watches and stops the watch processing go routine.
//
// The stop signal goes out first. The stopper channel is created unbuffered
// in NewWatcher, so this call blocks until the process goroutine has received
// the signal. Every registered watch is then sent a Canceled response and
// removed; a failed notification is logged and does not stop the sweep.
//
// NOTE(review): after the process goroutine has exited nothing receives from
// stopper any more, so a second call to Stop would block on the send.
func (w *Manager) Stop() {
	w.stopper <- true

	w.mutex.Lock()
	for qname, streams := range w.watches {
		for watchID, client := range streams {
			notice := pb.WatchResponse{WatchId: watchID, Canceled: true}
			if err := client.Send(&notice); err != nil {
				log.Warningf("Error notifying client of cancellation: %s\n", err)
			}
		}
		delete(w.watches, qname)
	}
	w.mutex.Unlock()
}