From f27799d1aa6285241f13d62408cd0a576d46f253 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Fri, 28 Jun 2024 22:03:22 +0100 Subject: [PATCH] Add custom TLS config to Redis We also update the Redis TLS config initialization in the app. Signed-off-by: Milos Gajdos --- configuration/configuration.go | 50 +++++++++++++++++++++++++++-- configuration/configuration_test.go | 29 +++++++++++++++-- registry/handlers/app.go | 31 ++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/configuration/configuration.go b/configuration/configuration.go index 884552da..253c0615 100644 --- a/configuration/configuration.go +++ b/configuration/configuration.go @@ -654,7 +654,12 @@ func Parse(rd io.Reader) (*Configuration, error) { } type Redis struct { - redis.UniversalOptions + redis.UniversalOptions `yaml:",inline"` + TLS struct { + Certificate string `yaml:"certificate,omitempty"` + Key string `yaml:"key,omitempty"` + ClientCAs []string `yaml:"clientcas,omitempty"` + } `yaml:"tls,omitempty"` } func (c Redis) MarshalYAML() (interface{}, error) { @@ -667,14 +672,19 @@ func (c Redis) MarshalYAML() (interface{}, error) { field := typ.Field(i) fieldValue := val.Field(i) - // ignore imports and funcs - if field.PkgPath != "" || fieldValue.Kind() == reflect.Func { + // ignore funcs fields in redis.UniversalOptions + if fieldValue.Kind() == reflect.Func { continue } fields[strings.ToLower(field.Name)] = fieldValue.Interface() } + // Add TLS fields if they're not empty + if c.TLS.Certificate != "" || c.TLS.Key != "" || len(c.TLS.ClientCAs) > 0 { + fields["tls"] = c.TLS + } + return fields, nil } @@ -715,6 +725,40 @@ func (c *Redis) UnmarshalYAML(unmarshal func(interface{}) error) error { } } + // Handle TLS fields + if tlsData, ok := fields["tls"]; ok { + tlsMap, ok := tlsData.(map[interface{}]interface{}) + if !ok { + return fmt.Errorf("invalid TLS data structure") + } + + if cert, ok := tlsMap["certificate"]; ok { + var isString bool + c.TLS.Certificate, isString = cert.(string) + if !isString { + return fmt.Errorf("Redis TLS certificate must be a string") + } + } + if key, ok := tlsMap["key"]; ok { + var isString bool + c.TLS.Key, isString = key.(string) + if !isString { + return fmt.Errorf("Redis TLS (private) key must be a string") + } + } + if cas, ok := tlsMap["clientcas"]; ok { + caList, ok := cas.([]interface{}) + if !ok { + return fmt.Errorf("invalid clientcas data structure") + } + for _, ca := range caList { + if caStr, ok := ca.(string); ok { + c.TLS.ClientCAs = append(c.TLS.ClientCAs, caStr) + } + } + } + } + return nil } diff --git a/configuration/configuration_test.go b/configuration/configuration_test.go index b7018807..73085367 100644 --- a/configuration/configuration_test.go +++ b/configuration/configuration_test.go @@ -132,7 +132,7 @@ var configStruct = Configuration{ }, }, Redis: Redis{ - redis.UniversalOptions{ + UniversalOptions: redis.UniversalOptions{ Addrs: []string{"localhost:6379"}, Username: "alice", Password: "123456", @@ -144,6 +144,15 @@ var configStruct = Configuration{ ReadTimeout: time.Millisecond * 10, WriteTimeout: time.Millisecond * 10, }, + TLS: struct { + Certificate string `yaml:"certificate,omitempty"` + Key string `yaml:"key,omitempty"` + ClientCAs []string `yaml:"clientcas,omitempty"` + }{ + Certificate: "/foo/cert.crt", + Key: "/foo/key.pem", + ClientCAs: []string{"/path/to/ca.pem"}, + }, }, } @@ -182,11 +191,17 @@ notifications: actions: - pull http: - clientcas: - - /path/to/ca.pem + tls: + clientcas: + - /path/to/ca.pem headers: X-Content-Type-Options: [nosniff] redis: + tls: + certificate: /foo/cert.crt + key: /foo/key.pem + clientcas: + - /path/to/ca.pem addrs: [localhost:6379] username: alice password: "123456" @@ -265,6 +280,7 @@ func (suite *ConfigSuite) TestParseSimple() { func (suite *ConfigSuite) TestParseInmemory() { suite.expectedConfig.Storage = Storage{"inmemory": Parameters{}} suite.expectedConfig.Log.Fields = nil + suite.expectedConfig.HTTP.TLS.ClientCAs = nil suite.expectedConfig.Redis = Redis{} config, err := Parse(bytes.NewReader([]byte(inmemoryConfigYamlV0_1))) @@ -285,6 +301,7 @@ func (suite *ConfigSuite) TestParseIncomplete() { suite.expectedConfig.Auth = Auth{"silly": Parameters{"realm": "silly"}} suite.expectedConfig.Notifications = Notifications{} suite.expectedConfig.HTTP.Headers = nil + suite.expectedConfig.HTTP.TLS.ClientCAs = nil suite.expectedConfig.Redis = Redis{} // Note: this also tests that REGISTRY_STORAGE and @@ -551,8 +568,14 @@ func copyConfig(config Configuration) *Configuration { for k, v := range config.HTTP.Headers { configCopy.HTTP.Headers[k] = v } + configCopy.HTTP.TLS.ClientCAs = make([]string, 0, len(config.HTTP.TLS.ClientCAs)) + configCopy.HTTP.TLS.ClientCAs = append(configCopy.HTTP.TLS.ClientCAs, config.HTTP.TLS.ClientCAs...) configCopy.Redis = config.Redis + configCopy.Redis.TLS.Certificate = config.Redis.TLS.Certificate + configCopy.Redis.TLS.Key = config.Redis.TLS.Key + configCopy.Redis.TLS.ClientCAs = make([]string, 0, len(config.Redis.TLS.ClientCAs)) + configCopy.Redis.TLS.ClientCAs = append(configCopy.Redis.TLS.ClientCAs, config.Redis.TLS.ClientCAs...) return configCopy } diff --git a/registry/handlers/app.go b/registry/handlers/app.go index e108dc2e..414ea8db 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -3,6 +3,8 @@ package handlers import ( "context" "crypto/rand" + "crypto/tls" + "crypto/x509" "expvar" "fmt" "math" @@ -492,6 +494,35 @@ func (app *App) configureRedis(cfg *configuration.Configuration) { return } + // redis TLS config + if cfg.Redis.TLS.Certificate != "" || cfg.Redis.TLS.Key != "" { + var err error + tlsConf := &tls.Config{} + tlsConf.Certificates = make([]tls.Certificate, 1) + tlsConf.Certificates[0], err = tls.LoadX509KeyPair(cfg.Redis.TLS.Certificate, cfg.Redis.TLS.Key) + if err != nil { + panic(err) + } + if len(cfg.Redis.TLS.ClientCAs) != 0 { + pool := x509.NewCertPool() + for _, ca := range cfg.Redis.TLS.ClientCAs { + caPem, err := os.ReadFile(ca) + if err != nil { + dcontext.GetLogger(app).Errorf("failed reading redis client CA: %v", err) + return + } + + if ok := pool.AppendCertsFromPEM(caPem); !ok { + dcontext.GetLogger(app).Error("could not add CA to pool") + return + } + } + tlsConf.ClientAuth = tls.RequireAndVerifyClientCert + tlsConf.ClientCAs = pool + } + cfg.Redis.UniversalOptions.TLSConfig = tlsConf + } + app.redis = app.createPool(cfg.Redis.UniversalOptions) // Enable metrics instrumentation.