forked from TrueCloudLab/distribution
Add custom TLS config to Redis
We also update the Redis TLS config initialization in the app. Signed-off-by: Milos Gajdos <milosthegajdos@gmail.com>
This commit is contained in:
parent
b63cbb3318
commit
f27799d1aa
3 changed files with 104 additions and 6 deletions
|
@ -654,7 +654,12 @@ func Parse(rd io.Reader) (*Configuration, error) {
|
|||
}
|
||||
|
||||
type Redis struct {
|
||||
redis.UniversalOptions
|
||||
redis.UniversalOptions `yaml:",inline"`
|
||||
TLS struct {
|
||||
Certificate string `yaml:"certificate,omitempty"`
|
||||
Key string `yaml:"key,omitempty"`
|
||||
ClientCAs []string `yaml:"clientcas,omitempty"`
|
||||
} `yaml:"tls,omitempty"`
|
||||
}
|
||||
|
||||
func (c Redis) MarshalYAML() (interface{}, error) {
|
||||
|
@ -667,14 +672,19 @@ func (c Redis) MarshalYAML() (interface{}, error) {
|
|||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// ignore imports and funcs
|
||||
if field.PkgPath != "" || fieldValue.Kind() == reflect.Func {
|
||||
// ignore funcs fields in redis.UniversalOptions
|
||||
if fieldValue.Kind() == reflect.Func {
|
||||
continue
|
||||
}
|
||||
|
||||
fields[strings.ToLower(field.Name)] = fieldValue.Interface()
|
||||
}
|
||||
|
||||
// Add TLS fields if they're not empty
|
||||
if c.TLS.Certificate != "" || c.TLS.Key != "" || len(c.TLS.ClientCAs) > 0 {
|
||||
fields["tls"] = c.TLS
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
|
@ -715,6 +725,40 @@ func (c *Redis) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Handle TLS fields
|
||||
if tlsData, ok := fields["tls"]; ok {
|
||||
tlsMap, ok := tlsData.(map[interface{}]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid TLS data structure")
|
||||
}
|
||||
|
||||
if cert, ok := tlsMap["certificate"]; ok {
|
||||
var isString bool
|
||||
c.TLS.Certificate, isString = cert.(string)
|
||||
if !isString {
|
||||
return fmt.Errorf("Redis TLS certificate must be a string")
|
||||
}
|
||||
}
|
||||
if key, ok := tlsMap["key"]; ok {
|
||||
var isString bool
|
||||
c.TLS.Key, isString = key.(string)
|
||||
if !isString {
|
||||
return fmt.Errorf("Redis TLS (private) key must be a string")
|
||||
}
|
||||
}
|
||||
if cas, ok := tlsMap["clientcas"]; ok {
|
||||
caList, ok := cas.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid clientcas data structure")
|
||||
}
|
||||
for _, ca := range caList {
|
||||
if caStr, ok := ca.(string); ok {
|
||||
c.TLS.ClientCAs = append(c.TLS.ClientCAs, caStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ var configStruct = Configuration{
|
|||
},
|
||||
},
|
||||
Redis: Redis{
|
||||
redis.UniversalOptions{
|
||||
UniversalOptions: redis.UniversalOptions{
|
||||
Addrs: []string{"localhost:6379"},
|
||||
Username: "alice",
|
||||
Password: "123456",
|
||||
|
@ -144,6 +144,15 @@ var configStruct = Configuration{
|
|||
ReadTimeout: time.Millisecond * 10,
|
||||
WriteTimeout: time.Millisecond * 10,
|
||||
},
|
||||
TLS: struct {
|
||||
Certificate string `yaml:"certificate,omitempty"`
|
||||
Key string `yaml:"key,omitempty"`
|
||||
ClientCAs []string `yaml:"clientcas,omitempty"`
|
||||
}{
|
||||
Certificate: "/foo/cert.crt",
|
||||
Key: "/foo/key.pem",
|
||||
ClientCAs: []string{"/path/to/ca.pem"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -182,11 +191,17 @@ notifications:
|
|||
actions:
|
||||
- pull
|
||||
http:
|
||||
tls:
|
||||
clientcas:
|
||||
- /path/to/ca.pem
|
||||
headers:
|
||||
X-Content-Type-Options: [nosniff]
|
||||
redis:
|
||||
tls:
|
||||
certificate: /foo/cert.crt
|
||||
key: /foo/key.pem
|
||||
clientcas:
|
||||
- /path/to/ca.pem
|
||||
addrs: [localhost:6379]
|
||||
username: alice
|
||||
password: "123456"
|
||||
|
@ -265,6 +280,7 @@ func (suite *ConfigSuite) TestParseSimple() {
|
|||
func (suite *ConfigSuite) TestParseInmemory() {
|
||||
suite.expectedConfig.Storage = Storage{"inmemory": Parameters{}}
|
||||
suite.expectedConfig.Log.Fields = nil
|
||||
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
|
||||
suite.expectedConfig.Redis = Redis{}
|
||||
|
||||
config, err := Parse(bytes.NewReader([]byte(inmemoryConfigYamlV0_1)))
|
||||
|
@ -285,6 +301,7 @@ func (suite *ConfigSuite) TestParseIncomplete() {
|
|||
suite.expectedConfig.Auth = Auth{"silly": Parameters{"realm": "silly"}}
|
||||
suite.expectedConfig.Notifications = Notifications{}
|
||||
suite.expectedConfig.HTTP.Headers = nil
|
||||
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
|
||||
suite.expectedConfig.Redis = Redis{}
|
||||
|
||||
// Note: this also tests that REGISTRY_STORAGE and
|
||||
|
@ -551,8 +568,14 @@ func copyConfig(config Configuration) *Configuration {
|
|||
for k, v := range config.HTTP.Headers {
|
||||
configCopy.HTTP.Headers[k] = v
|
||||
}
|
||||
configCopy.HTTP.TLS.ClientCAs = make([]string, 0, len(config.HTTP.TLS.ClientCAs))
|
||||
configCopy.HTTP.TLS.ClientCAs = append(configCopy.HTTP.TLS.ClientCAs, config.HTTP.TLS.ClientCAs...)
|
||||
|
||||
configCopy.Redis = config.Redis
|
||||
configCopy.Redis.TLS.Certificate = config.Redis.TLS.Certificate
|
||||
configCopy.Redis.TLS.Key = config.Redis.TLS.Key
|
||||
configCopy.Redis.TLS.ClientCAs = make([]string, 0, len(config.Redis.TLS.ClientCAs))
|
||||
configCopy.Redis.TLS.ClientCAs = append(configCopy.Redis.TLS.ClientCAs, config.Redis.TLS.ClientCAs...)
|
||||
|
||||
return configCopy
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package handlers
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"math"
|
||||
|
@ -492,6 +494,35 @@ func (app *App) configureRedis(cfg *configuration.Configuration) {
|
|||
return
|
||||
}
|
||||
|
||||
// redis TLS config
|
||||
if cfg.Redis.TLS.Certificate != "" || cfg.Redis.TLS.Key != "" {
|
||||
var err error
|
||||
tlsConf := &tls.Config{}
|
||||
tlsConf.Certificates = make([]tls.Certificate, 1)
|
||||
tlsConf.Certificates[0], err = tls.LoadX509KeyPair(cfg.Redis.TLS.Certificate, cfg.Redis.TLS.Key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(cfg.Redis.TLS.ClientCAs) != 0 {
|
||||
pool := x509.NewCertPool()
|
||||
for _, ca := range cfg.Redis.TLS.ClientCAs {
|
||||
caPem, err := os.ReadFile(ca)
|
||||
if err != nil {
|
||||
dcontext.GetLogger(app).Errorf("failed reading redis client CA: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if ok := pool.AppendCertsFromPEM(caPem); !ok {
|
||||
dcontext.GetLogger(app).Error("could not add CA to pool")
|
||||
return
|
||||
}
|
||||
}
|
||||
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConf.ClientCAs = pool
|
||||
}
|
||||
cfg.Redis.UniversalOptions.TLSConfig = tlsConf
|
||||
}
|
||||
|
||||
app.redis = app.createPool(cfg.Redis.UniversalOptions)
|
||||
|
||||
// Enable metrics instrumentation.
|
||||
|
|
Loading…
Reference in a new issue