package errors import ( "bytes" "context" "errors" "fmt" golog "log" "regexp" "strings" "sync/atomic" "testing" "time" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/dnstest" clog "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" ) func TestErrors(t *testing.T) { buf := bytes.Buffer{} golog.SetOutput(&buf) em := errorHandler{} testErr := errors.New("test error") tests := []struct { next plugin.Handler expectedCode int expectedLog string expectedErr error }{ { next: genErrorHandler(dns.RcodeSuccess, nil), expectedCode: dns.RcodeSuccess, expectedLog: "", expectedErr: nil, }, { next: genErrorHandler(dns.RcodeNotAuth, testErr), expectedCode: dns.RcodeNotAuth, expectedLog: fmt.Sprintf("%d %s: %v\n", dns.RcodeNotAuth, "example.org. A", testErr), expectedErr: testErr, }, } ctx := context.TODO() req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) for i, tc := range tests { em.Next = tc.next buf.Reset() rec := dnstest.NewRecorder(&test.ResponseWriter{}) code, err := em.ServeDNS(ctx, rec, req) if err != tc.expectedErr { t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) } if code != tc.expectedCode { t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) } if log := buf.String(); !strings.Contains(log, tc.expectedLog) { t.Errorf("Test %d: Expected log %q, but got %q", i, tc.expectedLog, log) } } } func TestLogPattern(t *testing.T) { type args struct { logCallback func(format string, v ...interface{}) } tests := []struct { name string args args want string }{ { name: "error log", args: args{logCallback: log.Errorf}, want: "[ERROR] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", }, { name: "warn log", args: args{logCallback: log.Warningf}, want: "[WARNING] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", }, { name: "info log", args: args{logCallback: log.Infof}, want: "[INFO] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", }, { name: "debug log", args: args{logCallback: log.Debugf}, want: "[DEBUG] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { buf := bytes.Buffer{} clog.D.Set() golog.SetOutput(&buf) h := &errorHandler{ patterns: []*pattern{{ count: 4, period: 2 * time.Second, pattern: regexp.MustCompile("^error.*!$"), logCallback: tt.args.logCallback, }}, } h.logPattern(0) if log := buf.String(); !strings.Contains(log, tt.want) { t.Errorf("Expected log %q, but got %q", tt.want, log) } }) } } func TestInc(t *testing.T) { h := &errorHandler{ stopFlag: 1, patterns: []*pattern{{ period: 2 * time.Second, pattern: regexp.MustCompile("^error.*!$"), }}, } ret := h.inc(0) if ret { t.Error("Unexpected return value, expected false, actual true") } h.stopFlag = 0 ret = h.inc(0) if !ret { t.Error("Unexpected return value, expected true, actual false") } expCnt := uint32(1) actCnt := atomic.LoadUint32(&h.patterns[0].count) if actCnt != expCnt { t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) } t1 := h.patterns[0].timer() if t1 == nil { t.Error("Unexpected 'timer', expected not nil") } ret = h.inc(0) if !ret { t.Error("Unexpected return value, expected true, actual false") } expCnt = uint32(2) actCnt = atomic.LoadUint32(&h.patterns[0].count) if actCnt != expCnt { t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) } t2 := h.patterns[0].timer() if t2 != t1 { t.Error("Unexpected 'timer', expected the same") } ret = t1.Stop() if !ret { t.Error("Timer was unexpectedly stopped before") } ret = t2.Stop() if ret { t.Error("Timer was unexpectedly not stopped before") } } func TestStop(t *testing.T) { buf := bytes.Buffer{} golog.SetOutput(&buf) h := &errorHandler{ patterns: []*pattern{{ period: 2 * time.Second, pattern: regexp.MustCompile("^error.*!$"), logCallback: log.Errorf, }}, } h.inc(0) h.inc(0) h.inc(0) expCnt := uint32(3) actCnt := atomic.LoadUint32(&h.patterns[0].count) if actCnt != expCnt { t.Fatalf("Unexpected initial 'count', expected %d, actual %d", expCnt, actCnt) } h.stop() expCnt = uint32(0) actCnt = atomic.LoadUint32(&h.patterns[0].count) if actCnt != expCnt { t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) } expStop := uint32(1) actStop := h.stopFlag if actStop != expStop { t.Errorf("Unexpected 'stop', expected %d, actual %d", expStop, actStop) } t1 := h.patterns[0].timer() if t1 == nil { t.Error("Unexpected 'timer', expected not nil") } else if t1.Stop() { t.Error("Timer was unexpectedly not stopped before") } expLog := "3 errors like '^error.*!$' occurred in last 2s" if log := buf.String(); !strings.Contains(log, expLog) { t.Errorf("Expected log %q, but got %q", expLog, log) } } func genErrorHandler(rcode int, err error) plugin.Handler { return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { return rcode, err }) }