diff --git a/api/log/log.go b/api/log/log.go index dc030c39..687d61c6 100644 --- a/api/log/log.go +++ b/api/log/log.go @@ -7,8 +7,6 @@ import ( "os" "github.com/pkg/errors" - - "github.com/smallstep/certificates/logging" ) // StackTracedError is the set of errors implementing the StackTrace function. @@ -21,16 +19,21 @@ type StackTracedError interface { StackTrace() errors.StackTrace } +type fieldCarrier interface { + WithFields(map[string]any) + Fields() map[string]any +} + // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. func Error(rw http.ResponseWriter, err error) { - rl, ok := rw.(logging.ResponseLogger) + fc, ok := rw.(fieldCarrier) if !ok { return } - rl.WithFields(map[string]interface{}{ + fc.WithFields(map[string]any{ "error": err, }) @@ -39,8 +42,8 @@ func Error(rw http.ResponseWriter, err error) { } var st StackTracedError - if !errors.As(err, &st) { - rl.WithFields(map[string]interface{}{ + if errors.As(err, &st) { + fc.WithFields(map[string]any{ "stack-trace": fmt.Sprintf("%+v", st.StackTrace()), }) } @@ -48,9 +51,9 @@ func Error(rw http.ResponseWriter, err error) { // EnabledResponse log the response object if it implements the EnableLogger // interface. -func EnabledResponse(rw http.ResponseWriter, v interface{}) { +func EnabledResponse(rw http.ResponseWriter, v any) { type enableLogger interface { - ToLog() (interface{}, error) + ToLog() (any, error) } if el, ok := v.(enableLogger); ok { @@ -61,8 +64,8 @@ func EnabledResponse(rw http.ResponseWriter, v interface{}) { return } - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ + if rl, ok := rw.(fieldCarrier); ok { + rl.WithFields(map[string]any{ "response": out, }) } diff --git a/api/log/log_test.go b/api/log/log_test.go index fcd3ea2b..7c08b771 100644 --- a/api/log/log_test.go +++ b/api/log/log_test.go @@ -1,43 +1,78 @@ package log import ( - "errors" "net/http" "net/http/httptest" - "reflect" "testing" + "unsafe" + + pkgerrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" "github.com/smallstep/certificates/logging" ) -func TestError(t *testing.T) { - theError := errors.New("the error") +type stackTracedError struct{} - type args struct { - rw http.ResponseWriter - err error +func (stackTracedError) Error() string { + return "a stacktraced error" +} + +func (stackTracedError) StackTrace() pkgerrors.StackTrace { + f := struct{}{} + return pkgerrors.StackTrace{ // fake stacktrace + pkgerrors.Frame(unsafe.Pointer(&f)), + pkgerrors.Frame(unsafe.Pointer(&f)), } +} + +func TestError(t *testing.T) { tests := []struct { - name string - args args - withFields bool + name string + error + rw http.ResponseWriter + isFieldCarrier bool + stepDebug bool + expectStackTrace bool }{ - {"normalLogger", args{httptest.NewRecorder(), theError}, false}, - {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true}, + {"noLogger", nil, nil, false, false, false}, + {"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, + {"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, + {"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, + {"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, + {"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, + {"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - Error(tt.args.rw, tt.args.err) - if tt.withFields { - if rl, ok := tt.args.rw.(logging.ResponseLogger); ok { - fields := rl.Fields() - if !reflect.DeepEqual(fields["error"], theError) { - t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError) - } - } else { - t.Error("ResponseWriter does not implement logging.ResponseLogger") - } + if tt.stepDebug { + t.Setenv("STEPDEBUG", "1") + } else { + t.Setenv("STEPDEBUG", "0") + } + + Error(tt.rw, tt.error) + + // return early if test case doesn't use logger + if !tt.isFieldCarrier { + return + } + + fields := tt.rw.(logging.ResponseLogger).Fields() + + // expect the error field to be (not) set and to be the same error that was fed to Error + if tt.error == nil { + assert.Nil(t, fields["error"]) + } else { + assert.Same(t, tt.error, fields["error"]) + } + + // check if stack-trace is set when expected + if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace { + t.Error(`ResponseLogger["stack-trace"] not set`) + } else if !tt.expectStackTrace && hasStackTrace { + t.Error(`ResponseLogger["stack-trace"] was set`) } }) }