diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go index 4d7e220d..a6b8d62c 100644 --- a/cas/vaultcas/vaultcas.go +++ b/cas/vaultcas/vaultcas.go @@ -1,12 +1,15 @@ package vaultcas import ( + "bytes" "context" "crypto/sha256" "crypto/x509" + "encoding/hex" "encoding/json" "encoding/pem" "math/big" + "strings" "time" "github.com/pkg/errors" @@ -14,7 +17,6 @@ import ( vault "github.com/hashicorp/vault/api" auth "github.com/hashicorp/vault/api/auth/approle" - certutil "github.com/hashicorp/vault/sdk/helper/certutil" ) func init() { @@ -28,7 +30,7 @@ type VaultOptions struct { PKIRoleDefault string `json:"PKIRoleDefault,omitempty"` PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` PKIRoleEC string `json:"pkiRoleEC,omitempty"` - PKIRoleEd25519 string `json:"PKIRoleEd25519,omitempty"` + PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` RoleID string `json:"roleID,omitempty"` SecretID auth.SecretID `json:"secretID,omitempty"` AppRole string `json:"appRole,omitempty"` @@ -42,207 +44,12 @@ type VaultCAS struct { fingerprint string } -type Certificate struct { +type certBundle struct { leaf *x509.Certificate intermediates []*x509.Certificate root *x509.Certificate } -func loadOptions(config json.RawMessage) (*VaultOptions, error) { - var vc *VaultOptions - - err := json.Unmarshal(config, &vc) - if err != nil { - return nil, errors.Wrap(err, "error decoding vaultCAS config") - } - - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.PKIRoleDefault == "" { - vc.PKIRoleDefault = "default" // use default pki role name - } - - if vc.PKIRoleRSA == "" { - vc.PKIRoleRSA = vc.PKIRoleDefault - } - if vc.PKIRoleEC == "" { - vc.PKIRoleEC = vc.PKIRoleDefault - } - if vc.PKIRoleEd25519 == "" { - vc.PKIRoleEd25519 = vc.PKIRoleDefault - } - - if vc.RoleID == "" { - return nil, errors.New("vaultCAS config options must define `roleID`") - } - - if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" { - return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") - } - - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.AppRole == "" { - vc.AppRole = "auth/approle" - } - - return vc, nil -} - -func certificateSort(n []*x509.Certificate) bool { - // sort all cert using bubble sort - isSorted := false - s := 0 - maxSwap := len(n) * (len(n) - 1) / 2 - for s <= maxSwap && !isSorted { - isSorted = true - var i = 0 - for i < len(n)-1 { - if !isSignedBy(n[i], n[i+1]) { - // swap - n[i], n[i+1] = n[i+1], n[i] - isSorted = false - } - i++ - } - s++ - } - return isSorted -} - -func isSignedBy(i, j *x509.Certificate) bool { - signer := x509.NewCertPool() - signer.AddCert(j) - - opts := x509.VerifyOptions{ - Roots: signer, - Intermediates: x509.NewCertPool(), // set empty to avoid using system CA - } - _, err := i.Verify(opts) - return err == nil -} - -func parseCertificates(pemCert string) []*x509.Certificate { - var certs []*x509.Certificate - rest := []byte(pemCert) - var block *pem.Block - for { - block, rest = pem.Decode(rest) - if block == nil { - break - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - break - } - certs = append(certs, cert) - } - return certs -} - -func getCertificateAndChain(certb certutil.CertBundle) (*Certificate, error) { - // certutil.CertBundle contains CAChain and Certificate. - // Both could have a common part or different and we are not sure - // how user define their chain inside vault. - // We will create an array of certificate with all parsed certificates - // then sort the array to create a consistent chain - var root *x509.Certificate - var leaf *x509.Certificate - intermediates := make([]*x509.Certificate, 0) - used := make(map[string]bool) // ensure that intermediate are uniq - for _, chain := range append(certb.CAChain, certb.Certificate) { - for _, cert := range parseCertificates(chain) { - if used[cert.SerialNumber.String()] { - continue - } - used[cert.SerialNumber.String()] = true - switch { - case isRoot(cert): - root = cert - case cert.BasicConstraintsValid && cert.IsCA: - intermediates = append(intermediates, cert) - default: - leaf = cert - } - } - } - if ok := certificateSort(intermediates); !ok { - return nil, errors.Errorf("failed to sort certificate, probably one of cert is not part of the chain") - } - - certificate := &Certificate{ - root: root, - leaf: leaf, - intermediates: intermediates, - } - - return certificate, nil -} - -func parseCertificateRequest(pemCsr string) (*x509.CertificateRequest, error) { - block, _ := pem.Decode([]byte(pemCsr)) - if block == nil { - return nil, errors.Errorf("error decoding certificate request: not a valid PEM encoded block, please verify\r\n%v", pemCsr) - } - - cr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - return nil, errors.Wrap(err, "error parsing certificate request") - } - return cr, nil -} - -func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { - var vaultPKIRole string - - switch { - case cr.PublicKeyAlgorithm == x509.RSA: - vaultPKIRole = v.config.PKIRoleRSA - case cr.PublicKeyAlgorithm == x509.ECDSA: - vaultPKIRole = v.config.PKIRoleEC - case cr.PublicKeyAlgorithm == x509.Ed25519: - vaultPKIRole = v.config.PKIRoleEd25519 - default: - return nil, nil, errors.Errorf("createCertificate: Unsupported public key algorithm '%v'", cr.PublicKeyAlgorithm) - } - - certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: cr.Raw}) - if certPemBytes == nil { - return nil, nil, errors.Errorf("createCertificate: Failed to encode pem '%v'", cr.Raw) - } - - y := map[string]interface{}{ - "csr": string(certPemBytes), - "format": "pem_bundle", - "ttl": lifetime.Seconds(), - } - - secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, y) - if err != nil { - return nil, nil, errors.Wrapf(err, "createCertificate: unable to sign certificate %v", y) - } - if secret == nil { - return nil, nil, errors.New("createCertificate: secret sign is empty") - } - - var certBundle certutil.CertBundle - if err := unmarshalMap(secret.Data, &certBundle); err != nil { - return nil, nil, errors.Wrap(err, "error unmarshaling cert bundle") - } - - cert, err := getCertificateAndChain(certBundle) - if err != nil { - return nil, nil, err - } - - // Return certificate and certificate chain - return cert.leaf, cert.intermediates, nil -} - // New creates a new CertificateAuthorityService implementation // using Hashicorp Vault func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { @@ -305,9 +112,9 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.CSR == nil: - return nil, errors.New("CreateCertificate: `CSR` cannot be nil") + return nil, errors.New("createCertificate `csr` cannot be nil") case req.Lifetime == 0: - return nil, errors.New("CreateCertificate: `LIFETIME` cannot be 0") + return nil, errors.New("createCertificate `lifetime` cannot be 0") } cert, chain, err := v.createCertificate(req.CSR, req.Lifetime) @@ -324,26 +131,28 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain") if err != nil { - return nil, errors.Wrap(err, "unable to read root") + return nil, errors.Wrap(err, "error reading ca chain") } if secret == nil { - return nil, errors.New("secret root is empty") + return nil, errors.New("error reading ca chain: response is empty") } - var certBundle certutil.CertBundle - if err := unmarshalMap(secret.Data, &certBundle); err != nil { - return nil, errors.Wrap(err, "error unmarshaling cert bundle") + chain, ok := secret.Data["certificate"].(string) + if !ok { + return nil, errors.New("error unmarshaling vault response: certificate not found") } - cert, err := getCertificateAndChain(certBundle) + cert, err := getCertificateBundle(chain) if err != nil { return nil, err } + if cert.root == nil { + return nil, errors.New("error unmarshaling vault response: root certificate not found") + } - sha256Sum := sha256.Sum256(cert.root.Raw) - expectedSum := certutil.GetHexFormatted(sha256Sum[:], "") - if expectedSum != v.fingerprint { - return nil, errors.Errorf("Vault Root CA fingerprint `%s` doesn't match config fingerprint `%v`", expectedSum, v.fingerprint) + sum := sha256.Sum256(cert.root.Raw) + if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) { + return nil, errors.New("error verifying vault root: fingerprint does not match") } return &apiv1.GetCertificateAuthorityResponse{ @@ -357,34 +166,28 @@ func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1. return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"} } +// RevokeCertificate revokes a certificate by serial number. func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { if req.SerialNumber == "" && req.Certificate == nil { - return nil, errors.New("`serialNumber` or `certificate` are required") + return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required") } - var serialNumber []byte + var sn *big.Int if req.SerialNumber != "" { - // req.SerialNumber is a big.Int string representation - n := new(big.Int) - n, ok := n.SetString(req.SerialNumber, 10) - if !ok { - return nil, errors.Errorf("serialNumber `%v` can't be convert to big.Int", req.SerialNumber) + var ok bool + if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok { + return nil, errors.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber) } - serialNumber = n.Bytes() } else { - // req.Certificate.SerialNumber is a big.Int - serialNumber = req.Certificate.SerialNumber.Bytes() + sn = req.Certificate.SerialNumber } - serialNumberDash := certutil.GetHexFormatted(serialNumber, "-") - - y := map[string]interface{}{ - "serial_number": serialNumberDash, + vaultReq := map[string]interface{}{ + "serial_number": formatSerialNumber(sn), } - - _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", y) + _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq) if err != nil { - return nil, errors.Wrap(err, "unable to revoke certificate") + return nil, errors.Wrap(err, "error revoking certificate") } return &apiv1.RevokeCertificateResponse{ @@ -393,13 +196,136 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv }, nil } -func unmarshalMap(m map[string]interface{}, v interface{}) error { - b, err := json.Marshal(m) - if err != nil { - return err +func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { + var vaultPKIRole string + + switch { + case cr.PublicKeyAlgorithm == x509.RSA: + vaultPKIRole = v.config.PKIRoleRSA + case cr.PublicKeyAlgorithm == x509.ECDSA: + vaultPKIRole = v.config.PKIRoleEC + case cr.PublicKeyAlgorithm == x509.Ed25519: + vaultPKIRole = v.config.PKIRoleEd25519 + default: + return nil, nil, errors.Errorf("unsupported public key algorithm '%v'", cr.PublicKeyAlgorithm) } - return json.Unmarshal(b, v) + vaultReq := map[string]interface{}{ + "csr": string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: cr.Raw, + })), + "format": "pem_bundle", + "ttl": lifetime.Seconds(), + } + + secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq) + if err != nil { + return nil, nil, errors.Wrap(err, "error signing certificate") + } + if secret == nil { + return nil, nil, errors.New("error signing certificate: response is empty") + } + + chain, ok := secret.Data["certificate"].(string) + if !ok { + return nil, nil, errors.New("error unmarshaling vault response: certificate not found") + } + + cert, err := getCertificateBundle(chain) + if err != nil { + return nil, nil, err + } + + // Return certificate and certificate chain + return cert.leaf, cert.intermediates, nil +} + +func loadOptions(config json.RawMessage) (*VaultOptions, error) { + var vc *VaultOptions + + err := json.Unmarshal(config, &vc) + if err != nil { + return nil, errors.Wrap(err, "error decoding vaultCAS config") + } + + if vc.PKI == "" { + vc.PKI = "pki" // use default pki vault name + } + + if vc.PKIRoleDefault == "" { + vc.PKIRoleDefault = "default" // use default pki role name + } + + if vc.PKIRoleRSA == "" { + vc.PKIRoleRSA = vc.PKIRoleDefault + } + if vc.PKIRoleEC == "" { + vc.PKIRoleEC = vc.PKIRoleDefault + } + if vc.PKIRoleEd25519 == "" { + vc.PKIRoleEd25519 = vc.PKIRoleDefault + } + + if vc.RoleID == "" { + return nil, errors.New("vaultCAS config options must define `roleID`") + } + + if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" { + return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") + } + + if vc.PKI == "" { + vc.PKI = "pki" // use default pki vault name + } + + if vc.AppRole == "" { + vc.AppRole = "auth/approle" + } + + return vc, nil +} + +func parseCertificates(pemCert string) []*x509.Certificate { + var certs []*x509.Certificate + rest := []byte(pemCert) + var block *pem.Block + for { + block, rest = pem.Decode(rest) + if block == nil { + break + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + break + } + certs = append(certs, cert) + } + return certs +} + +func getCertificateBundle(chain string) (*certBundle, error) { + var root *x509.Certificate + var leaf *x509.Certificate + var intermediates []*x509.Certificate + for _, cert := range parseCertificates(chain) { + switch { + case isRoot(cert): + root = cert + case cert.BasicConstraintsValid && cert.IsCA: + intermediates = append(intermediates, cert) + default: + leaf = cert + } + } + + certificate := &certBundle{ + root: root, + leaf: leaf, + intermediates: intermediates, + } + + return certificate, nil } // isRoot returns true if the given certificate is a root certificate. @@ -409,3 +335,16 @@ func isRoot(cert *x509.Certificate) bool { } return false } + +// formatSerialNumber formats a serial number to a dash-separated hexadecimal +// string. +func formatSerialNumber(sn *big.Int) string { + var ret bytes.Buffer + for _, b := range sn.Bytes() { + if ret.Len() > 0 { + ret.WriteString("-") + } + ret.WriteString(hex.EncodeToString([]byte{b})) + } + return ret.String() +} diff --git a/cas/vaultcas/vaultcas_test.go b/cas/vaultcas/vaultcas_test.go index 1febf1ce..9f73a1ee 100644 --- a/cas/vaultcas/vaultcas_test.go +++ b/cas/vaultcas/vaultcas_test.go @@ -16,6 +16,7 @@ import ( vault "github.com/hashicorp/vault/api" auth "github.com/hashicorp/vault/api/auth/approle" "github.com/smallstep/certificates/cas/apiv1" + "go.step.sm/crypto/pemutil" ) var ( @@ -80,13 +81,13 @@ func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate { return crt } -func mustParseCertificateRequest(t *testing.T, pemCert string) *x509.CertificateRequest { +func mustParseCertificateRequest(t *testing.T, pemData string) *x509.CertificateRequest { t.Helper() - crt, err := parseCertificateRequest(pemCert) + csr, err := pemutil.ParseCertificateRequest([]byte(pemData)) if err != nil { t.Fatal(err) } - return crt + return csr } func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { @@ -107,17 +108,17 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { }`) case r.RequestURI == "/v1/pki/sign/ec": w.WriteHeader(http.StatusOK) - cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}} + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case r.RequestURI == "/v1/pki/sign/rsa": w.WriteHeader(http.StatusOK) - cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}} + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case r.RequestURI == "/v1/pki/sign/ed25519": w.WriteHeader(http.StatusOK) - cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}} + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case r.RequestURI == "/v1/pki/cert/ca_chain": @@ -232,21 +233,21 @@ func TestVaultCAS_CreateCertificate(t *testing.T) { Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), - CertificateChain: []*x509.Certificate{}, + CertificateChain: nil, }, false}, {"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrRsa), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), - CertificateChain: []*x509.Certificate{}, + CertificateChain: nil, }, false}, {"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEd25519), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), - CertificateChain: []*x509.Certificate{}, + CertificateChain: nil, }, false}, {"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: nil,