rclone/fs/fshttp/http_test.go
Saleh Dindar f26d2c6ba8 fs/http: reload client certificates on expiry
In corporate environments, client certificates have short lifetimes
for added security, and they get renewed automatically. This means
that a client certificate can expire in the middle of a long-running
command such as `mount`.

This commit attempts to reload the client certificates 30s before they
expire.

This will be active for all backends which use HTTP.
2024-07-24 15:02:32 +01:00

177 lines
5.6 KiB
Go

package fshttp
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/rclone/rclone/fs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCleanAuth runs cleanAuth with the first entry of authBufs over a
// table of header dumps and compares the result with the expected text.
//
// The table expects the value following "Authorization: " to be
// overwritten with one X per character, capped at four, while an empty
// value, unrelated lines and input without the header stay untouched.
func TestCleanAuth(t *testing.T) {
	type testCase struct {
		in   string
		want string
	}
	cases := []testCase{
		{"", ""},
		{"floo", "floo"},
		{"Authorization: ", "Authorization: "},
		{"Authorization: \n", "Authorization: \n"},
		{"Authorization: A", "Authorization: X"},
		{"Authorization: A\n", "Authorization: X\n"},
		{"Authorization: AAAA", "Authorization: XXXX"},
		{"Authorization: AAAA\n", "Authorization: XXXX\n"},
		{"Authorization: AAAAA", "Authorization: XXXX"},
		{"Authorization: AAAAA\n", "Authorization: XXXX\n"},
		{"Authorization: AAAA\n", "Authorization: XXXX\n"},
		{"Authorization: AAAAAAAAA\nPotato: Help\n", "Authorization: XXXX\nPotato: Help\n"},
		{"Sausage: 1\nAuthorization: AAAAAAAAA\nPotato: Help\n", "Sausage: 1\nAuthorization: XXXX\nPotato: Help\n"},
	}
	for i := range cases {
		tc := cases[i]
		// cleanAuth works on a byte slice, so hand it a fresh copy of the input
		cleaned := cleanAuth([]byte(tc.in), authBufs[0])
		assert.Equal(t, tc.want, string(cleaned), tc.in)
	}
}
// TestCleanAuths runs cleanAuths over a table of header dumps and
// compares the result with the expected text.
//
// The table expects the values of both the Authorization and the
// X-Auth-Token headers to be replaced with XXXX, including when both
// appear in the same input, and everything else to be left alone.
func TestCleanAuths(t *testing.T) {
	type testCase struct {
		in   string
		want string
	}
	cases := []testCase{
		{"", ""},
		{"floo", "floo"},
		{"Authorization: AAAAAAAAA\nPotato: Help\n", "Authorization: XXXX\nPotato: Help\n"},
		{"X-Auth-Token: AAAAAAAAA\nPotato: Help\n", "X-Auth-Token: XXXX\nPotato: Help\n"},
		{"X-Auth-Token: AAAAAAAAA\nAuthorization: AAAAAAAAA\nPotato: Help\n", "X-Auth-Token: XXXX\nAuthorization: XXXX\nPotato: Help\n"},
	}
	for i := range cases {
		tc := cases[i]
		cleaned := cleanAuths([]byte(tc.in))
		assert.Equal(t, tc.want, string(cleaned), tc.in)
	}
}
// certSerial is the serial number given to the most recently created
// test certificate. createTestCert increments it before each use, so
// certificates are numbered 1, 2, 3, ... in creation order.
var certSerial = int64(0)
// createTestCert generates a self-signed certificate and its RSA
// private key, both PEM encoded.
//
// The certificate's NotBefore is the current time and its NotAfter is
// validity+expireWindow later. Each call increments the package level
// certSerial and uses the new value as the certificate serial number.
//
// On failure both PEM results are nil and err is non-nil.
func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) {
	key, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		return nil, nil, err
	}

	// PEM encoding of private key
	keyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	// Now create the certificate.
	//
	// NOTE(review): expireWindow is added on top of validity, presumably
	// so that the transport regards the certificate as due for reload
	// once validity has elapsed - confirm against the transport code.
	notBefore := time.Now()
	notAfter := notBefore.Add(validity).Add(expireWindow)

	certSerial++
	template := x509.Certificate{
		SerialNumber:          big.NewInt(certSerial),
		Subject:               pkix.Name{CommonName: "localhost"},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
	}

	// The template is passed as its own parent, making the certificate
	// self-signed with the key generated above.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		// Do not hand back the key without a matching certificate
		return nil, nil, err
	}

	certPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: derBytes,
	})
	return keyPEM, certPEM, nil
}
// writeTestCert creates a new test certificate and key with
// createTestCert(validity) and writes them, PEM encoded, to the files
// named by ci.ClientCert and ci.ClientKey.
//
// Any failure stops the calling test immediately, so it must be called
// from the goroutine running the test.
func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) {
	t.Helper()
	keyPEM, certPEM, err := createTestCert(validity)
	require.NoError(t, err, "Cannot create test cert")
	err = os.WriteFile(ci.ClientCert, certPEM, 0666)
	require.NoError(t, err, "Failed to write cert")
	// The private key only needs to be readable by the test itself
	err = os.WriteFile(ci.ClientKey, keyPEM, 0600)
	require.NoError(t, err, "Failed to write key")
}
func TestCertificates(t *testing.T) {
startTime := time.Now()
// Starting a TLS server
expectedSerial := int64(0)
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cert := r.TLS.PeerCertificates
require.Greater(t, len(cert), 0, "No certificates received")
expectedSerial += 1
assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate")
// Check that the certificate hasn't expired. We cannot use cert validation
// functions because those check for signature as well and our certificates
// are not properly signed
if time.Now().After(cert[0].NotAfter) {
assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime))
}
// Write some test data to fullfil the request
w.Header().Set("Content-Type", "text/plain")
_, _ = fmt.Fprintln(w, "test data")
}))
defer ts.Close()
// Modify servers config to request a client certificate
// we cannot validate the certificate since we are not properly signing it
ts.TLS.ClientAuth = tls.RequestClientCert
// Set --client-cert and --client-key in config to
// a pair of temp files
// create a test cert/key pair and write it to the files
ctx := context.TODO()
ci := fs.GetConfig(ctx)
// Create a test certificate and write it to a temp file
ci.ClientCert = t.TempDir() + "client.cert"
ci.ClientKey = t.TempDir() + "client.key"
validity := 1 * time.Second
writeTestCert(t, ci, validity)
// Now create the client with the above settings
// we need to disable TLS verification since we don't
// care about server certificate
client := NewClient(ctx)
tt := client.Transport.(*Transport)
tt.TLSClientConfig.InsecureSkipVerify = true
// Now make requests, the first request should be within
// the valid window
_, err := client.Get(ts.URL)
assert.NoError(t, err)
// Wait for the 2* valid duration of the certificate so that has definitely expired
time.Sleep(2 * validity)
// Create a new cert and write it to files
writeTestCert(t, ci, validity)
// The new cert should be auto-loaded before we make this request
_, err = client.Get(ts.URL)
assert.NoError(t, err)
}