certificates/cas/vaultcas/vaultcas.go
2023-05-09 23:47:28 -07:00

329 lines
9.1 KiB
Go

package vaultcas
import (
"bytes"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"time"
"github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/cas/vaultcas/auth/approle"
"github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes"
vault "github.com/hashicorp/vault/api"
)
// init registers the Vault-backed CertificateAuthorityService constructor
// under apiv1.VaultCAS.
func init() {
	apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) {
		cas, err := New(ctx, opts)
		if err != nil {
			// Return an untyped nil: returning New's (*VaultCAS)(nil) directly
			// would yield a non-nil interface value wrapping a nil pointer.
			return nil, err
		}
		return cas, nil
	})
}
// VaultOptions defines the configuration options added using the
// apiv1.Options.Config field.
//
// It is decoded from JSON by loadOptions, which defaults PKIMountPath to
// "pki" and PKIRoleDefault to "default", and fills every empty per-key-type
// role with PKIRoleDefault.
type VaultOptions struct {
	// PKIMountPath is the mount path of the Vault PKI secrets engine; it is
	// used as the prefix of every PKI request path.
	PKIMountPath string `json:"pkiMountPath,omitempty"`
	// PKIRoleDefault is the role used for any key type without its own role.
	PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
	// PKIRoleRSA is the role used to sign CSRs carrying RSA public keys.
	PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
	// PKIRoleEC is the role used to sign CSRs carrying ECDSA public keys.
	PKIRoleEC string `json:"pkiRoleEC,omitempty"`
	// PKIRoleEd25519 is the role used to sign CSRs carrying Ed25519 public keys.
	PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
	// AuthType selects the Vault auth method; New accepts "kubernetes" or
	// "approle".
	AuthType string `json:"authType,omitempty"`
	// AuthMountPath is passed to the selected auth method's constructor.
	AuthMountPath string `json:"authMountPath,omitempty"`
	// AuthOptions holds method-specific settings, passed undecoded to the
	// selected auth method's constructor.
	AuthOptions json.RawMessage `json:"authOptions,omitempty"`
}
// VaultCAS implements a Certificate Authority Service using HashiCorp Vault.
type VaultCAS struct {
	// client is the Vault API client; New logs it in with the configured
	// auth method before returning the VaultCAS.
	client *vault.Client
	// config holds the decoded VaultOptions (PKI mount path, roles, auth).
	config VaultOptions
	// fingerprint is the expected hex-encoded SHA-256 fingerprint of the
	// root certificate; GetCertificateAuthority compares it case-insensitively.
	fingerprint string
}
// certBundle is a PEM certificate chain split into its parts by
// getCertificateBundle.
type certBundle struct {
	// leaf is a certificate that is neither self-signed CA nor a CA per its
	// basic constraints (the last such certificate seen, if several).
	leaf *x509.Certificate
	// intermediates are CA certificates that are not self-signed, in input
	// order.
	intermediates []*x509.Certificate
	// root is a self-signed CA certificate (the last one seen, if several).
	root *x509.Certificate
}
// New builds a VaultCAS. It requires opts.CertificateAuthority (the Vault
// address) and opts.CertificateAuthorityFingerprint, decodes opts.Config
// into VaultOptions, creates a Vault client and logs it in using the
// configured auth method ("kubernetes" or "approle").
func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
	switch {
	case opts.CertificateAuthority == "":
		return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty")
	case opts.CertificateAuthorityFingerprint == "":
		return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty")
	}

	vc, err := loadOptions(opts.Config)
	if err != nil {
		return nil, err
	}

	clientConfig := vault.DefaultConfig()
	clientConfig.Address = opts.CertificateAuthority
	client, err := vault.NewClient(clientConfig)
	if err != nil {
		return nil, fmt.Errorf("unable to initialize vault client: %w", err)
	}

	// Pick the auth method; an unknown type is rejected before any login.
	var authMethod vault.AuthMethod
	switch vc.AuthType {
	case "kubernetes":
		authMethod, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions)
	case "approle":
		authMethod, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions)
	default:
		return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType)
	}
	if err != nil {
		return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err)
	}

	authInfo, err := client.Auth().Login(ctx, authMethod)
	switch {
	case err != nil:
		return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err)
	case authInfo == nil:
		return nil, errors.New("no auth info was returned after login")
	}

	cas := &VaultCAS{
		client:      client,
		config:      *vc,
		fingerprint: opts.CertificateAuthorityFingerprint,
	}
	return cas, nil
}
// CreateCertificate signs req.CSR through the Vault PKI engine for
// req.Lifetime and returns the signed certificate together with its
// intermediate chain. Both a CSR and a non-zero lifetime are required.
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
	if req.CSR == nil {
		return nil, errors.New("createCertificate `csr` cannot be nil")
	}
	if req.Lifetime == 0 {
		return nil, errors.New("createCertificate `lifetime` cannot be 0")
	}

	leaf, intermediates, err := v.createCertificate(req.CSR, req.Lifetime)
	if err != nil {
		return nil, err
	}

	resp := &apiv1.CreateCertificateResponse{
		Certificate:      leaf,
		CertificateChain: intermediates,
	}
	return resp, nil
}
// GetCertificateAuthority reads the CA chain from <PKIMountPath>/cert/ca_chain,
// extracts the self-signed root certificate and returns it only if its
// SHA-256 fingerprint matches the configured fingerprint (compared
// case-insensitively).
func (v *VaultCAS) GetCertificateAuthority(*apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
	secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain")
	if err != nil {
		return nil, fmt.Errorf("error reading ca chain: %w", err)
	}
	if secret == nil {
		return nil, errors.New("error reading ca chain: response is empty")
	}
	chain, ok := secret.Data["certificate"].(string)
	if !ok {
		return nil, errors.New("error unmarshaling vault response: certificate not found")
	}
	cert, err := getCertificateBundle(chain)
	if err != nil {
		return nil, err
	}
	if cert.root == nil {
		return nil, errors.New("error unmarshaling vault response: root certificate not found")
	}

	// EqualFold already ignores case, so the hex digest needs no lowercasing.
	sum := sha256.Sum256(cert.root.Raw)
	if !strings.EqualFold(v.fingerprint, hex.EncodeToString(sum[:])) {
		return nil, errors.New("error verifying vault root: fingerprint does not match")
	}
	return &apiv1.GetCertificateAuthorityResponse{
		RootCertificate: cert.root,
	}, nil
}
// RenewCertificate is not supported by this backend: it always returns an
// apiv1.NotImplementedError and never contacts Vault.
func (v *VaultCAS) RenewCertificate(*apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
	notImplemented := apiv1.NotImplementedError{
		Message: "vaultCAS does not support renewals",
	}
	return nil, notImplemented
}
// RevokeCertificate asks Vault to revoke a certificate by serial number.
// A decimal req.SerialNumber takes precedence; otherwise the serial of
// req.Certificate is used. The serial is sent to <PKIMountPath>/revoke/ as
// dash-separated hex, and req.Certificate is echoed back in the response.
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
	if req.SerialNumber == "" && req.Certificate == nil {
		return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required")
	}

	var serial *big.Int
	switch {
	case req.SerialNumber == "":
		serial = req.Certificate.SerialNumber
	default:
		parsed, ok := new(big.Int).SetString(req.SerialNumber, 10)
		if !ok {
			return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber)
		}
		serial = parsed
	}

	payload := map[string]interface{}{
		"serial_number": formatSerialNumber(serial),
	}
	if _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", payload); err != nil {
		return nil, fmt.Errorf("error revoking certificate: %w", err)
	}

	resp := &apiv1.RevokeCertificateResponse{
		Certificate:      req.Certificate,
		CertificateChain: nil,
	}
	return resp, nil
}
// createCertificate submits cr to <PKIMountPath>/sign/<role>, where the role
// is chosen by the CSR's public key algorithm (RSA, ECDSA or Ed25519), and
// requests a pem_bundle with the given lifetime as TTL. It returns the signed
// leaf certificate and the intermediates found in Vault's response.
//
// An error is returned for unsupported key algorithms, Vault failures, an
// empty or malformed response, and a response that contains no leaf
// certificate.
func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
	var vaultPKIRole string
	switch cr.PublicKeyAlgorithm {
	case x509.RSA:
		vaultPKIRole = v.config.PKIRoleRSA
	case x509.ECDSA:
		vaultPKIRole = v.config.PKIRoleEC
	case x509.Ed25519:
		vaultPKIRole = v.config.PKIRoleEd25519
	default:
		return nil, nil, fmt.Errorf("unsupported public key algorithm %v", cr.PublicKeyAlgorithm)
	}

	vaultReq := map[string]interface{}{
		"csr": string(pem.EncodeToMemory(&pem.Block{
			Type:  "CERTIFICATE REQUEST",
			Bytes: cr.Raw,
		})),
		"format": "pem_bundle",
		"ttl":    lifetime.String(),
	}
	secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq)
	if err != nil {
		return nil, nil, fmt.Errorf("error signing certificate: %w", err)
	}
	if secret == nil {
		return nil, nil, errors.New("error signing certificate: response is empty")
	}
	chain, ok := secret.Data["certificate"].(string)
	if !ok {
		return nil, nil, errors.New("error unmarshaling vault response: certificate not found")
	}
	cert, err := getCertificateBundle(chain)
	if err != nil {
		return nil, nil, err
	}
	// Never hand callers a nil certificate with a nil error.
	if cert.leaf == nil {
		return nil, nil, errors.New("error unmarshaling vault response: leaf certificate not found")
	}

	// Return certificate and certificate chain
	return cert.leaf, cert.intermediates, nil
}
// loadOptions decodes the raw vaultCAS JSON configuration. PKIMountPath
// defaults to "pki" and PKIRoleDefault to "default"; any per-key-type role
// left empty after decoding inherits PKIRoleDefault.
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
	opts := &VaultOptions{
		PKIMountPath:   "pki",
		PKIRoleDefault: "default",
	}
	if err := json.Unmarshal(config, opts); err != nil {
		return nil, fmt.Errorf("error decoding vaultCAS config: %w", err)
	}

	for _, role := range []*string{&opts.PKIRoleRSA, &opts.PKIRoleEC, &opts.PKIRoleEd25519} {
		if *role == "" {
			*role = opts.PKIRoleDefault
		}
	}
	return opts, nil
}
// parseCertificates decodes every PEM "CERTIFICATE" block in pemCert, in
// order. PEM blocks of any other type (keys, CSRs, ...) are skipped instead
// of truncating the result. Parsing stops at the first CERTIFICATE block
// whose contents are not a valid X.509 certificate; the certificates decoded
// before it are still returned. An input with no certificates yields nil.
func parseCertificates(pemCert string) []*x509.Certificate {
	var certs []*x509.Certificate
	rest := []byte(pemCert)
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			break
		}
		certs = append(certs, cert)
	}
	return certs
}
// getCertificateBundle parses a PEM chain returned by Vault and classifies
// each certificate: a self-signed CA becomes the root, any other CA (per its
// basic constraints) is appended to the intermediates, and everything else
// is treated as the leaf. When several certificates fall in the root or leaf
// category, the last one wins.
//
// It returns an error if the chain contains no parseable certificates.
func getCertificateBundle(chain string) (*certBundle, error) {
	certs := parseCertificates(chain)
	if len(certs) == 0 {
		return nil, errors.New("error unmarshaling vault response: no certificates found")
	}

	bundle := &certBundle{}
	for _, cert := range certs {
		switch {
		case isRoot(cert):
			bundle.root = cert
		case cert.BasicConstraintsValid && cert.IsCA:
			bundle.intermediates = append(bundle.intermediates, cert)
		default:
			bundle.leaf = cert
		}
	}
	return bundle, nil
}
// isRoot reports whether cert is a self-signed CA: it must carry valid basic
// constraints marking it as a CA, and its signature must verify against its
// own public key.
func isRoot(cert *x509.Certificate) bool {
	if !cert.BasicConstraintsValid || !cert.IsCA {
		return false
	}
	return cert.CheckSignatureFrom(cert) == nil
}
// formatSerialNumber renders sn's big-endian magnitude bytes as lowercase
// two-digit hex pairs joined by dashes, e.g. 0x0102ab -> "01-02-ab".
// A zero serial has no bytes and yields the empty string.
func formatSerialNumber(sn *big.Int) string {
	raw := sn.Bytes()
	var buf bytes.Buffer
	for i := range raw {
		if i > 0 {
			buf.WriteByte('-')
		}
		buf.WriteString(hex.EncodeToString(raw[i : i+1]))
	}
	return buf.String()
}