diff --git a/plugin/transfer/failed_write_test.go b/plugin/transfer/failed_write_test.go new file mode 100644 index 000000000..90b5c4de2 --- /dev/null +++ b/plugin/transfer/failed_write_test.go @@ -0,0 +1,31 @@ +package transfer + +import ( + "context" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type badwriter struct { + dns.ResponseWriter + count int +} + +func (w *badwriter) WriteMsg(res *dns.Msg) error { return fmt.Errorf("failed to write msg") } + +func TestWriteMessageFailed(t *testing.T) { + transfer := newTestTransfer() + ctx := context.TODO() + w := &badwriter{ResponseWriter: &test.ResponseWriter{}} + m := &dns.Msg{} + m.SetAxfr("example.org.") + + _, err := transfer.ServeDNS(ctx, w, m) + if err == nil { + t.Error("Expected error, got none") + } +} diff --git a/plugin/transfer/transfer.go b/plugin/transfer/transfer.go index 3558f2e0f..45251cda0 100644 --- a/plugin/transfer/transfer.go +++ b/plugin/transfer/transfer.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net" - "sync" "github.com/coredns/coredns/plugin" clog "github.com/coredns/coredns/plugin/pkg/log" @@ -107,11 +106,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms // Send response to client ch := make(chan *dns.Envelope) tr := new(dns.Transfer) - wg := new(sync.WaitGroup) - wg.Add(1) + errCh := make(chan error) go func() { - tr.Out(w, r, ch) - wg.Done() + if err := tr.Out(w, r, ch); err != nil { + errCh <- err + } + close(errCh) }() rrs := []dns.RR{} @@ -123,7 +123,11 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms } rrs = append(rrs, records...) if len(rrs) > 500 { - ch <- &dns.Envelope{RR: rrs} + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + return dns.RcodeServerFailure, err + } l += len(rrs) rrs = []dns.RR{} } @@ -134,7 +138,10 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms // need to return the SOA back to the client and return. if len(rrs) == 1 && soa != nil { // soa should never be nil... close(ch) - wg.Wait() + err := <-errCh + if err != nil { + return dns.RcodeServerFailure, err + } m := new(dns.Msg) m.SetReply(r) @@ -146,12 +153,20 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms } if len(rrs) > 0 { - ch <- &dns.Envelope{RR: rrs} + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + return dns.RcodeServerFailure, err + } l += len(rrs) + } - close(ch) // Even though we close the channel here, we still have - wg.Wait() // to wait before we can return and close the connection. + close(ch) // Even though we close the channel here, we still have + err = <-errCh // to wait before we can return and close the connection. + if err != nil { + return dns.RcodeServerFailure, err + } logserial := uint32(0) if soa != nil {