Add custom TLS config to Redis

We also update the Redis TLS config initialization in the app.

Signed-off-by: Milos Gajdos <milosthegajdos@gmail.com>
This commit is contained in:
Milos Gajdos 2024-06-28 22:03:22 +01:00
parent b63cbb3318
commit f27799d1aa
No known key found for this signature in database
3 changed files with 104 additions and 6 deletions

View file

@ -654,7 +654,12 @@ func Parse(rd io.Reader) (*Configuration, error) {
} }
type Redis struct { type Redis struct {
redis.UniversalOptions redis.UniversalOptions `yaml:",inline"`
TLS struct {
Certificate string `yaml:"certificate,omitempty"`
Key string `yaml:"key,omitempty"`
ClientCAs []string `yaml:"clientcas,omitempty"`
} `yaml:"tls,omitempty"`
} }
func (c Redis) MarshalYAML() (interface{}, error) { func (c Redis) MarshalYAML() (interface{}, error) {
@ -667,14 +672,19 @@ func (c Redis) MarshalYAML() (interface{}, error) {
field := typ.Field(i) field := typ.Field(i)
fieldValue := val.Field(i) fieldValue := val.Field(i)
// ignore imports and funcs // ignore funcs fields in redis.UniversalOptions
if field.PkgPath != "" || fieldValue.Kind() == reflect.Func { if fieldValue.Kind() == reflect.Func {
continue continue
} }
fields[strings.ToLower(field.Name)] = fieldValue.Interface() fields[strings.ToLower(field.Name)] = fieldValue.Interface()
} }
// Add TLS fields if they're not empty
if c.TLS.Certificate != "" || c.TLS.Key != "" || len(c.TLS.ClientCAs) > 0 {
fields["tls"] = c.TLS
}
return fields, nil return fields, nil
} }
@ -715,6 +725,40 @@ func (c *Redis) UnmarshalYAML(unmarshal func(interface{}) error) error {
} }
} }
// Handle TLS fields
if tlsData, ok := fields["tls"]; ok {
tlsMap, ok := tlsData.(map[interface{}]interface{})
if !ok {
return fmt.Errorf("invalid TLS data structure")
}
if cert, ok := tlsMap["certificate"]; ok {
var isString bool
c.TLS.Certificate, isString = cert.(string)
if !isString {
return fmt.Errorf("Redis TLS certificate must be a string")
}
}
if key, ok := tlsMap["key"]; ok {
var isString bool
c.TLS.Key, isString = key.(string)
if !isString {
return fmt.Errorf("Redis TLS (private) key must be a string")
}
}
if cas, ok := tlsMap["clientcas"]; ok {
caList, ok := cas.([]interface{})
if !ok {
return fmt.Errorf("invalid clientcas data structure")
}
for _, ca := range caList {
if caStr, ok := ca.(string); ok {
c.TLS.ClientCAs = append(c.TLS.ClientCAs, caStr)
}
}
}
}
return nil return nil
} }

View file

@ -132,7 +132,7 @@ var configStruct = Configuration{
}, },
}, },
Redis: Redis{ Redis: Redis{
redis.UniversalOptions{ UniversalOptions: redis.UniversalOptions{
Addrs: []string{"localhost:6379"}, Addrs: []string{"localhost:6379"},
Username: "alice", Username: "alice",
Password: "123456", Password: "123456",
@ -144,6 +144,15 @@ var configStruct = Configuration{
ReadTimeout: time.Millisecond * 10, ReadTimeout: time.Millisecond * 10,
WriteTimeout: time.Millisecond * 10, WriteTimeout: time.Millisecond * 10,
}, },
TLS: struct {
Certificate string `yaml:"certificate,omitempty"`
Key string `yaml:"key,omitempty"`
ClientCAs []string `yaml:"clientcas,omitempty"`
}{
Certificate: "/foo/cert.crt",
Key: "/foo/key.pem",
ClientCAs: []string{"/path/to/ca.pem"},
},
}, },
} }
@ -182,11 +191,17 @@ notifications:
actions: actions:
- pull - pull
http: http:
tls:
clientcas: clientcas:
- /path/to/ca.pem - /path/to/ca.pem
headers: headers:
X-Content-Type-Options: [nosniff] X-Content-Type-Options: [nosniff]
redis: redis:
tls:
certificate: /foo/cert.crt
key: /foo/key.pem
clientcas:
- /path/to/ca.pem
addrs: [localhost:6379] addrs: [localhost:6379]
username: alice username: alice
password: "123456" password: "123456"
@ -265,6 +280,7 @@ func (suite *ConfigSuite) TestParseSimple() {
func (suite *ConfigSuite) TestParseInmemory() { func (suite *ConfigSuite) TestParseInmemory() {
suite.expectedConfig.Storage = Storage{"inmemory": Parameters{}} suite.expectedConfig.Storage = Storage{"inmemory": Parameters{}}
suite.expectedConfig.Log.Fields = nil suite.expectedConfig.Log.Fields = nil
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
suite.expectedConfig.Redis = Redis{} suite.expectedConfig.Redis = Redis{}
config, err := Parse(bytes.NewReader([]byte(inmemoryConfigYamlV0_1))) config, err := Parse(bytes.NewReader([]byte(inmemoryConfigYamlV0_1)))
@ -285,6 +301,7 @@ func (suite *ConfigSuite) TestParseIncomplete() {
suite.expectedConfig.Auth = Auth{"silly": Parameters{"realm": "silly"}} suite.expectedConfig.Auth = Auth{"silly": Parameters{"realm": "silly"}}
suite.expectedConfig.Notifications = Notifications{} suite.expectedConfig.Notifications = Notifications{}
suite.expectedConfig.HTTP.Headers = nil suite.expectedConfig.HTTP.Headers = nil
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
suite.expectedConfig.Redis = Redis{} suite.expectedConfig.Redis = Redis{}
// Note: this also tests that REGISTRY_STORAGE and // Note: this also tests that REGISTRY_STORAGE and
@ -551,8 +568,14 @@ func copyConfig(config Configuration) *Configuration {
for k, v := range config.HTTP.Headers { for k, v := range config.HTTP.Headers {
configCopy.HTTP.Headers[k] = v configCopy.HTTP.Headers[k] = v
} }
configCopy.HTTP.TLS.ClientCAs = make([]string, 0, len(config.HTTP.TLS.ClientCAs))
configCopy.HTTP.TLS.ClientCAs = append(configCopy.HTTP.TLS.ClientCAs, config.HTTP.TLS.ClientCAs...)
configCopy.Redis = config.Redis configCopy.Redis = config.Redis
configCopy.Redis.TLS.Certificate = config.Redis.TLS.Certificate
configCopy.Redis.TLS.Key = config.Redis.TLS.Key
configCopy.Redis.TLS.ClientCAs = make([]string, 0, len(config.Redis.TLS.ClientCAs))
configCopy.Redis.TLS.ClientCAs = append(configCopy.Redis.TLS.ClientCAs, config.Redis.TLS.ClientCAs...)
return configCopy return configCopy
} }

View file

@ -3,6 +3,8 @@ package handlers
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls"
"crypto/x509"
"expvar" "expvar"
"fmt" "fmt"
"math" "math"
@ -492,6 +494,35 @@ func (app *App) configureRedis(cfg *configuration.Configuration) {
return return
} }
// redis TLS config
if cfg.Redis.TLS.Certificate != "" || cfg.Redis.TLS.Key != "" {
var err error
tlsConf := &tls.Config{}
tlsConf.Certificates = make([]tls.Certificate, 1)
tlsConf.Certificates[0], err = tls.LoadX509KeyPair(cfg.Redis.TLS.Certificate, cfg.Redis.TLS.Key)
if err != nil {
panic(err)
}
if len(cfg.Redis.TLS.ClientCAs) != 0 {
pool := x509.NewCertPool()
for _, ca := range cfg.Redis.TLS.ClientCAs {
caPem, err := os.ReadFile(ca)
if err != nil {
dcontext.GetLogger(app).Errorf("failed reading redis client CA: %v", err)
return
}
if ok := pool.AppendCertsFromPEM(caPem); !ok {
dcontext.GetLogger(app).Error("could not add CA to pool")
return
}
}
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
tlsConf.ClientCAs = pool
}
cfg.Redis.UniversalOptions.TLSConfig = tlsConf
}
app.redis = app.createPool(cfg.Redis.UniversalOptions) app.redis = app.createPool(cfg.Redis.UniversalOptions)
// Enable metrics instrumentation. // Enable metrics instrumentation.