diff --git a/context/logger.go b/context/logger.go index 3e5b81bb..f956a228 100644 --- a/context/logger.go +++ b/context/logger.go @@ -4,10 +4,16 @@ import ( "context" "fmt" "runtime" + "sync" "github.com/sirupsen/logrus" ) +var ( + defaultLogger *logrus.Entry = logrus.StandardLogger().WithField("go.version", runtime.Version()) + defaultLoggerMu sync.RWMutex +) + // Logger provides a leveled-logging interface. type Logger interface { // standard logger methods @@ -80,6 +86,18 @@ func GetLogger(ctx context.Context, keys ...interface{}) Logger { return getLogrusLogger(ctx, keys...) } +// SetDefaultLogger sets the default logger upon which to base new loggers. +func SetDefaultLogger(logger Logger) { + entry, ok := logger.(*logrus.Entry) + if !ok { + return + } + + defaultLoggerMu.Lock() + defaultLogger = entry + defaultLoggerMu.Unlock() +} + // GetLogrusLogger returns the logrus logger for the context. If one more keys // are provided, they will be resolved on the context and included in the // logger. Only use this function if specific logrus functionality is @@ -104,9 +122,9 @@ func getLogrusLogger(ctx context.Context, keys ...interface{}) *logrus.Entry { fields["instance.id"] = instanceID } - fields["go.version"] = runtime.Version() - // If no logger is found, just return the standard logger. - logger = logrus.StandardLogger().WithFields(fields) + defaultLoggerMu.RLock() + logger = defaultLogger.WithFields(fields) + defaultLoggerMu.RUnlock() } fields := logrus.Fields{} diff --git a/registry/registry.go b/registry/registry.go index ea4d46e6..05e92066 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -386,6 +386,7 @@ func configureLogging(ctx context.Context, config *configuration.Configuration) ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, fields...)) } + dcontext.SetDefaultLogger(dcontext.GetLogger(ctx)) return ctx, nil } diff --git a/registry/registry_test.go b/registry/registry_test.go index 9ff0ddfb..e0d5c218 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -25,7 +25,10 @@ import ( "time" "github.com/distribution/distribution/v3/configuration" + dcontext "github.com/distribution/distribution/v3/context" _ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" ) // Tests to ensure nextProtos returns the correct protocols when: @@ -346,3 +349,59 @@ func TestRegistryUnsupportedCipherSuite(t *testing.T) { quit <- os.Interrupt time.Sleep(100 * time.Millisecond) } + +func TestConfigureLogging(t *testing.T) { + yamlConfig := `--- +log: + level: warn + fields: + foo: bar + baz: xyzzy +` + + var config configuration.Configuration + err := yaml.Unmarshal([]byte(yamlConfig), &config) + if err != nil { + t.Fatal("failed to parse config: ", err) + } + + ctx, err := configureLogging(context.Background(), &config) + if err != nil { + t.Fatal("failed to configure logging: ", err) + } + + // Check that the log level was set to Warn. + if logrus.IsLevelEnabled(logrus.InfoLevel) { + t.Error("expected Info to be disabled, is enabled") + } + + // Check that the returned context's logger includes the right fields. + logger := dcontext.GetLogger(ctx) + entry, ok := logger.(*logrus.Entry) + if !ok { + t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry) + } + val, ok := entry.Data["foo"].(string) + if !ok || val != "bar" { + t.Error("field foo not configured correctly; expected 'bar' got: ", val) + } + val, ok = entry.Data["baz"].(string) + if !ok || val != "xyzzy" { + t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val) + } + + // Get a logger for a new, empty context and make sure it also has the right fields. + logger = dcontext.GetLogger(context.Background()) + entry, ok = logger.(*logrus.Entry) + if !ok { + t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry) + } + val, ok = entry.Data["foo"].(string) + if !ok || val != "bar" { + t.Error("field foo not configured correctly; expected 'bar' got: ", val) + } + val, ok = entry.Data["baz"].(string) + if !ok || val != "xyzzy" { + t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val) + } +}