fs/http: reload client certificates on expiry

In corporate environments, client certificates have short life times
for added security, and they get renewed automatically. This means
that client certificate can expire in the middle of long running
command such as `mount`.

This commit attempts to reload the client certificates 30s before they
expire.

This will be active for all backends which use HTTP.
This commit is contained in:
Saleh Dindar 2023-10-24 22:01:42 -07:00 committed by Nick Craig-Wood
parent dcecb0ede4
commit f26d2c6ba8
2 changed files with 185 additions and 5 deletions

View file

@ -69,6 +69,13 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
if err != nil {
log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err)
}
if cert.Leaf == nil {
// Leaf is always the first certificate
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
log.Fatalf("Failed to parse the certificate")
}
}
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
}
@ -148,17 +155,24 @@ type Transport struct {
userAgent string
headers []*fs.HTTPOption
metrics *Metrics
// Filename of the client cert in case we need to reload it
clientCert string
clientKey string
// Mutex for serializing attempts at reloading the certificates
reloadMutex sync.Mutex
}
// newTransport wraps the http.Transport passed in and logs all
// roundtrips including the body if logBody is set.
func newTransport(ci *fs.ConfigInfo, transport *http.Transport) *Transport {
return &Transport{
Transport: transport,
dump: ci.Dump,
userAgent: ci.UserAgent,
headers: ci.Headers,
metrics: DefaultMetrics,
Transport: transport,
dump: ci.Dump,
userAgent: ci.UserAgent,
headers: ci.Headers,
metrics: DefaultMetrics,
clientCert: ci.ClientCert,
clientKey: ci.ClientKey,
}
}
@ -247,8 +261,44 @@ func cleanAuths(buf []byte) []byte {
return buf
}
var expireWindow = 30 * time.Second
func isCertificateExpired(cc *tls.Config) bool {
return len(cc.Certificates) > 0 && cc.Certificates[0].Leaf != nil && time.Until(cc.Certificates[0].Leaf.NotAfter) < expireWindow
}
func (t *Transport) reloadCertificates() {
t.reloadMutex.Lock()
defer t.reloadMutex.Unlock()
// Check that the certificate is expired before trying to reload it
// it might have been reloaded while we were waiting to lock the mutex
if !isCertificateExpired(t.TLSClientConfig) {
return
}
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
if err != nil {
log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err)
}
// Check if we need to parse the certificate again, we need it
// for checking the expiration date
if cert.Leaf == nil {
// Leaf is always the first certificate
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
log.Fatalf("Failed to parse the certificate")
}
}
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
}
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
// Check if certificates are being used and the certificates are expired
if isCertificateExpired(t.TLSClientConfig) {
t.reloadCertificates()
}
// Limit transactions per second if required
accounting.LimitTPS(req.Context())
// Force user agent

View file

@ -1,9 +1,24 @@
package fshttp
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/rclone/rclone/fs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCleanAuth(t *testing.T) {
@ -45,3 +60,118 @@ func TestCleanAuths(t *testing.T) {
assert.Equal(t, test.want, got, test.in)
}
}
var certSerial = int64(0)
// Create a test certificate and key pair that is valid for a specific
// duration
func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return
}
keyBytes := x509.MarshalPKCS1PrivateKey(key)
// PEM encoding of private key
keyPEM = pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: keyBytes,
},
)
// Now create the certificate
notBefore := time.Now()
notAfter := notBefore.Add(validity).Add(expireWindow)
certSerial += 1
template := x509.Certificate{
SerialNumber: big.NewInt(certSerial),
Subject: pkix.Name{CommonName: "localhost"},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return
}
certPEM = pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
},
)
return
}
func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) {
keyPEM, certPEM, err := createTestCert(1 * time.Second)
assert.NoError(t, err, "Cannot create test cert")
err = os.WriteFile(ci.ClientCert, certPEM, 0666)
assert.NoError(t, err, "Failed to write cert")
err = os.WriteFile(ci.ClientKey, keyPEM, 0666)
assert.NoError(t, err, "Failed to write key")
}
func TestCertificates(t *testing.T) {
startTime := time.Now()
// Starting a TLS server
expectedSerial := int64(0)
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cert := r.TLS.PeerCertificates
require.Greater(t, len(cert), 0, "No certificates received")
expectedSerial += 1
assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate")
// Check that the certificate hasn't expired. We cannot use cert validation
// functions because those check for signature as well and our certificates
// are not properly signed
if time.Now().After(cert[0].NotAfter) {
assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime))
}
// Write some test data to fullfil the request
w.Header().Set("Content-Type", "text/plain")
_, _ = fmt.Fprintln(w, "test data")
}))
defer ts.Close()
// Modify servers config to request a client certificate
// we cannot validate the certificate since we are not properly signing it
ts.TLS.ClientAuth = tls.RequestClientCert
// Set --client-cert and --client-key in config to
// a pair of temp files
// create a test cert/key pair and write it to the files
ctx := context.TODO()
ci := fs.GetConfig(ctx)
// Create a test certificate and write it to a temp file
ci.ClientCert = t.TempDir() + "client.cert"
ci.ClientKey = t.TempDir() + "client.key"
validity := 1 * time.Second
writeTestCert(t, ci, validity)
// Now create the client with the above settings
// we need to disable TLS verification since we don't
// care about server certificate
client := NewClient(ctx)
tt := client.Transport.(*Transport)
tt.TLSClientConfig.InsecureSkipVerify = true
// Now make requests, the first request should be within
// the valid window
_, err := client.Get(ts.URL)
assert.NoError(t, err)
// Wait for the 2* valid duration of the certificate so that has definitely expired
time.Sleep(2 * validity)
// Create a new cert and write it to files
writeTestCert(t, ci, validity)
// The new cert should be auto-loaded before we make this request
_, err = client.Get(ts.URL)
assert.NoError(t, err)
}