From 510aa8f871a243d8df90d853c10b5ae93bd4a54d Mon Sep 17 00:00:00 2001 From: Aofei Sheng Date: Sat, 6 May 2023 06:14:07 +0800 Subject: [PATCH] fix: archive only domain-related files on revoke (#1874) Co-authored-by: Fernandez Ludovic --- cmd/certs_storage.go | 31 +++++++--- cmd/certs_storage_test.go | 115 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 8 deletions(-) create mode 100644 cmd/certs_storage_test.go diff --git a/cmd/certs_storage.go b/cmd/certs_storage.go index ed18018c..3ddf0b10 100644 --- a/cmd/certs_storage.go +++ b/cmd/certs_storage.go @@ -27,6 +27,15 @@ const ( baseArchivesFolderName = "archives" ) +const ( + issuerExt = ".issuer.crt" + certExt = ".crt" + keyExt = ".key" + pemExt = ".pem" + pfxExt = ".pfx" + resourceExt = ".json" +) + // CertificatesStorage a certificates' storage. // // rootPath: @@ -84,13 +93,13 @@ func (s *CertificatesStorage) SaveResource(certRes *certificate.Resource) { // We store the certificate, private key and metadata in different files // as web servers would not be able to work with a combined file. - err := s.WriteFile(domain, ".crt", certRes.Certificate) + err := s.WriteFile(domain, certExt, certRes.Certificate) if err != nil { log.Fatalf("Unable to save Certificate for domain %s\n\t%v", domain, err) } if certRes.IssuerCertificate != nil { - err = s.WriteFile(domain, ".issuer.crt", certRes.IssuerCertificate) + err = s.WriteFile(domain, issuerExt, certRes.IssuerCertificate) if err != nil { log.Fatalf("Unable to save IssuerCertificate for domain %s\n\t%v", domain, err) } @@ -112,14 +121,14 @@ func (s *CertificatesStorage) SaveResource(certRes *certificate.Resource) { log.Fatalf("Unable to marshal CertResource for domain %s\n\t%v", domain, err) } - err = s.WriteFile(domain, ".json", jsonBytes) + err = s.WriteFile(domain, resourceExt, jsonBytes) if err != nil { log.Fatalf("Unable to save CertResource for domain %s\n\t%v", domain, err) } } func (s *CertificatesStorage) ReadResource(domain string) certificate.Resource { - raw, err := s.ReadFile(domain, ".json") + raw, err := s.ReadFile(domain, resourceExt) if err != nil { log.Fatalf("Error while loading the meta data for domain %s\n\t%v", domain, err) } @@ -176,13 +185,13 @@ func (s *CertificatesStorage) WriteFile(domain, extension string, data []byte) e } func (s *CertificatesStorage) WriteCertificateFiles(domain string, certRes *certificate.Resource) error { - err := s.WriteFile(domain, ".key", certRes.PrivateKey) + err := s.WriteFile(domain, keyExt, certRes.PrivateKey) if err != nil { return fmt.Errorf("unable to save key file: %w", err) } if s.pem { - err = s.WriteFile(domain, ".pem", bytes.Join([][]byte{certRes.Certificate, certRes.PrivateKey}, nil)) + err = s.WriteFile(domain, pemExt, bytes.Join([][]byte{certRes.Certificate, certRes.PrivateKey}, nil)) if err != nil { return fmt.Errorf("unable to save PEM file: %w", err) } @@ -247,16 +256,22 @@ func (s *CertificatesStorage) WritePFXFile(domain string, certRes *certificate.R return fmt.Errorf("unable to encode PFX data for domain %s: %w", domain, err) } - return s.WriteFile(domain, ".pfx", pfxBytes) + return s.WriteFile(domain, pfxExt, pfxBytes) } func (s *CertificatesStorage) MoveToArchive(domain string) error { - matches, err := filepath.Glob(filepath.Join(s.rootPath, sanitizedDomain(domain)+".*")) + baseFilename := filepath.Join(s.rootPath, sanitizedDomain(domain)) + + matches, err := filepath.Glob(baseFilename + ".*") if err != nil { return err } for _, oldFile := range matches { + if strings.TrimSuffix(oldFile, filepath.Ext(oldFile)) != baseFilename && oldFile != baseFilename+issuerExt { + continue + } + date := strconv.FormatInt(time.Now().Unix(), 10) filename := date + "." + filepath.Base(oldFile) newFile := filepath.Join(s.archivePath, filename) diff --git a/cmd/certs_storage_test.go b/cmd/certs_storage_test.go new file mode 100644 index 00000000..9a474f18 --- /dev/null +++ b/cmd/certs_storage_test.go @@ -0,0 +1,115 @@ +package cmd + +import ( + "os" + "path/filepath" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificatesStorage_MoveToArchive(t *testing.T) { + domain := "example.com" + + storage := CertificatesStorage{ + rootPath: t.TempDir(), + archivePath: t.TempDir(), + } + + domainFiles := generateTestFiles(t, storage.rootPath, domain) + + err := storage.MoveToArchive(domain) + require.NoError(t, err) + + for _, file := range domainFiles { + assert.NoFileExists(t, file) + } + + root, err := os.ReadDir(storage.rootPath) + require.NoError(t, err) + require.Empty(t, root) + + archive, err := os.ReadDir(storage.archivePath) + require.NoError(t, err) + + require.Len(t, archive, len(domainFiles)) + assert.Regexp(t, `\d+\.`+regexp.QuoteMeta(domain), archive[0].Name()) +} + +func TestCertificatesStorage_MoveToArchive_noFileRelatedToDomain(t *testing.T) { + domain := "example.com" + + storage := CertificatesStorage{ + rootPath: t.TempDir(), + archivePath: t.TempDir(), + } + + domainFiles := generateTestFiles(t, storage.rootPath, "example.org") + + err := storage.MoveToArchive(domain) + require.NoError(t, err) + + for _, file := range domainFiles { + assert.FileExists(t, file) + } + + root, err := os.ReadDir(storage.rootPath) + require.NoError(t, err) + assert.Len(t, root, len(domainFiles)) + + archive, err := os.ReadDir(storage.archivePath) + require.NoError(t, err) + + assert.Empty(t, archive) +} + +func TestCertificatesStorage_MoveToArchive_ambiguousDomain(t *testing.T) { + domain := "example.com" + + storage := CertificatesStorage{ + rootPath: t.TempDir(), + archivePath: t.TempDir(), + } + + domainFiles := generateTestFiles(t, storage.rootPath, domain) + otherDomainFiles := generateTestFiles(t, storage.rootPath, domain+".example.org") + + err := storage.MoveToArchive(domain) + require.NoError(t, err) + + for _, file := range domainFiles { + assert.NoFileExists(t, file) + } + + for _, file := range otherDomainFiles { + assert.FileExists(t, file) + } + + root, err := os.ReadDir(storage.rootPath) + require.NoError(t, err) + require.Len(t, root, len(otherDomainFiles)) + + archive, err := os.ReadDir(storage.archivePath) + require.NoError(t, err) + + require.Len(t, archive, len(domainFiles)) + assert.Regexp(t, `\d+\.`+regexp.QuoteMeta(domain), archive[0].Name()) +} + +func generateTestFiles(t *testing.T, dir, domain string) []string { + t.Helper() + + var filenames []string + + for _, ext := range []string{issuerExt, certExt, keyExt, pemExt, pfxExt, resourceExt} { + filename := filepath.Join(dir, domain+ext) + err := os.WriteFile(filename, []byte("test"), 0o666) + require.NoError(t, err) + + filenames = append(filenames, filename) + } + + return filenames +}