plugin/transfer: fix go-routine leak (#4380)

PR #4161 is stalled. Tried to cherry pick the code from there, but that
led to conflicts, manually copying over while taking into account the
comments on that PR. Use that code and extend the error checking, don't
modify existing tests and make the badwriter test simpler.

Closes: #4161

Signed-off-by: Miek Gieben <miek@miek.nl>

add tests

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2021-01-13 09:16:01 +01:00 committed by GitHub
parent fd705b4783
commit d31b118978
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 10 deletions

View file

@ -0,0 +1,31 @@
package transfer
import (
"context"
"fmt"
"testing"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
type badwriter struct {
dns.ResponseWriter
count int
}
func (w *badwriter) WriteMsg(res *dns.Msg) error { return fmt.Errorf("failed to write msg") }
func TestWriteMessageFailed(t *testing.T) {
transfer := newTestTransfer()
ctx := context.TODO()
w := &badwriter{ResponseWriter: &test.ResponseWriter{}}
m := &dns.Msg{}
m.SetAxfr("example.org.")
_, err := transfer.ServeDNS(ctx, w, m)
if err == nil {
t.Error("Expected error, got none")
}
}

View file

@ -4,7 +4,6 @@ import (
"context"
"errors"
"net"
"sync"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
@ -107,11 +106,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
// Send response to client
ch := make(chan *dns.Envelope)
tr := new(dns.Transfer)
wg := new(sync.WaitGroup)
wg.Add(1)
errCh := make(chan error)
go func() {
tr.Out(w, r, ch)
wg.Done()
if err := tr.Out(w, r, ch); err != nil {
errCh <- err
}
close(errCh)
}()
rrs := []dns.RR{}
@ -123,7 +123,11 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
}
rrs = append(rrs, records...)
if len(rrs) > 500 {
ch <- &dns.Envelope{RR: rrs}
select {
case ch <- &dns.Envelope{RR: rrs}:
case err := <-errCh:
return dns.RcodeServerFailure, err
}
l += len(rrs)
rrs = []dns.RR{}
}
@ -134,7 +138,10 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
// need to return the SOA back to the client and return.
if len(rrs) == 1 && soa != nil { // soa should never be nil...
close(ch)
wg.Wait()
err := <-errCh
if err != nil {
return dns.RcodeServerFailure, err
}
m := new(dns.Msg)
m.SetReply(r)
@ -146,12 +153,20 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
}
if len(rrs) > 0 {
ch <- &dns.Envelope{RR: rrs}
select {
case ch <- &dns.Envelope{RR: rrs}:
case err := <-errCh:
return dns.RcodeServerFailure, err
}
l += len(rrs)
}
close(ch) // Even though we close the channel here, we still have
wg.Wait() // to wait before we can return and close the connection.
close(ch) // Even though we close the channel here, we still have
err = <-errCh // to wait before we can return and close the connection.
if err != nil {
return dns.RcodeServerFailure, err
}
logserial := uint32(0)
if soa != nil {