plugin/transfer: fix go-routine leak (#4380)
PR #4161 is stalled. Tried to cherry pick the code from there, but that led to conflicts, manually copying over while taking into account the comments on that PR. Use that code and extend the error checking, don't modify existing tests and make the badwriter test simpler. Closes: #4161 Signed-off-by: Miek Gieben <miek@miek.nl> add tests Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
fd705b4783
commit
d31b118978
2 changed files with 56 additions and 10 deletions
31
plugin/transfer/failed_write_test.go
Normal file
31
plugin/transfer/failed_write_test.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package transfer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/test"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type badwriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *badwriter) WriteMsg(res *dns.Msg) error { return fmt.Errorf("failed to write msg") }
|
||||||
|
|
||||||
|
func TestWriteMessageFailed(t *testing.T) {
|
||||||
|
transfer := newTestTransfer()
|
||||||
|
ctx := context.TODO()
|
||||||
|
w := &badwriter{ResponseWriter: &test.ResponseWriter{}}
|
||||||
|
m := &dns.Msg{}
|
||||||
|
m.SetAxfr("example.org.")
|
||||||
|
|
||||||
|
_, err := transfer.ServeDNS(ctx, w, m)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error, got none")
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/coredns/coredns/plugin"
|
"github.com/coredns/coredns/plugin"
|
||||||
clog "github.com/coredns/coredns/plugin/pkg/log"
|
clog "github.com/coredns/coredns/plugin/pkg/log"
|
||||||
|
@ -107,11 +106,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||||
// Send response to client
|
// Send response to client
|
||||||
ch := make(chan *dns.Envelope)
|
ch := make(chan *dns.Envelope)
|
||||||
tr := new(dns.Transfer)
|
tr := new(dns.Transfer)
|
||||||
wg := new(sync.WaitGroup)
|
errCh := make(chan error)
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
go func() {
|
||||||
tr.Out(w, r, ch)
|
if err := tr.Out(w, r, ch); err != nil {
|
||||||
wg.Done()
|
errCh <- err
|
||||||
|
}
|
||||||
|
close(errCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rrs := []dns.RR{}
|
rrs := []dns.RR{}
|
||||||
|
@ -123,7 +123,11 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||||
}
|
}
|
||||||
rrs = append(rrs, records...)
|
rrs = append(rrs, records...)
|
||||||
if len(rrs) > 500 {
|
if len(rrs) > 500 {
|
||||||
ch <- &dns.Envelope{RR: rrs}
|
select {
|
||||||
|
case ch <- &dns.Envelope{RR: rrs}:
|
||||||
|
case err := <-errCh:
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
l += len(rrs)
|
l += len(rrs)
|
||||||
rrs = []dns.RR{}
|
rrs = []dns.RR{}
|
||||||
}
|
}
|
||||||
|
@ -134,7 +138,10 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||||
// need to return the SOA back to the client and return.
|
// need to return the SOA back to the client and return.
|
||||||
if len(rrs) == 1 && soa != nil { // soa should never be nil...
|
if len(rrs) == 1 && soa != nil { // soa should never be nil...
|
||||||
close(ch)
|
close(ch)
|
||||||
wg.Wait()
|
err := <-errCh
|
||||||
|
if err != nil {
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
|
@ -146,12 +153,20 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rrs) > 0 {
|
if len(rrs) > 0 {
|
||||||
ch <- &dns.Envelope{RR: rrs}
|
select {
|
||||||
|
case ch <- &dns.Envelope{RR: rrs}:
|
||||||
|
case err := <-errCh:
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
l += len(rrs)
|
l += len(rrs)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
close(ch) // Even though we close the channel here, we still have
|
close(ch) // Even though we close the channel here, we still have
|
||||||
wg.Wait() // to wait before we can return and close the connection.
|
err = <-errCh // to wait before we can return and close the connection.
|
||||||
|
if err != nil {
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
|
|
||||||
logserial := uint32(0)
|
logserial := uint32(0)
|
||||||
if soa != nil {
|
if soa != nil {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue