// certificates/cas/vaultcas/vaultcas.go
//
// (Listing header: 409 lines, 11 KiB, Go.)

package vaultcas
import (
"bytes"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/json"
"encoding/pem"
"math/big"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/cas/apiv1"
vault "github.com/hashicorp/vault/api"
auth "github.com/hashicorp/vault/api/auth/approle"
certutil "github.com/hashicorp/vault/sdk/helper/certutil"
)
// init registers the Vault CAS constructor with the apiv1 registry under
// apiv1.VaultCAS.
func init() {
	apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) {
		cas, err := New(ctx, opts)
		if err != nil {
			// Return a literal nil: converting a nil *VaultCAS to the
			// interface type would produce a non-nil interface value.
			return nil, err
		}
		return cas, nil
	})
}
// VaultOptions is the JSON configuration of the Vault CAS (the
// apiv1.Options.Config value decoded by loadOptions).
//
// Fields:
//   - PKI: mount path of the Vault PKI secrets engine (default "pki").
//   - PKIRoleDefault: PKI role used when a key-specific role is not set
//     (default "default").
//   - PKIRoleRSA, PKIRoleEC, PKIRoleEd25519: PKI roles used to sign CSRs
//     carrying RSA, ECDSA and Ed25519 public keys respectively.
//   - RoleID, SecretID: AppRole credentials; both are required.
//   - AppRole: mount path of the AppRole auth method (default "auth/approle").
//   - IsWrappingToken: treat SecretID as a response-wrapping token.
//
// All JSON keys use lower camel case. encoding/json matches object keys to
// tags case-insensitively, so keys spelled e.g. "PKIRoleDefault" or
// "PKIRoleEd25519" are still accepted when decoding.
type VaultOptions struct {
	PKI             string        `json:"pki,omitempty"`
	PKIRoleDefault  string        `json:"pkiRoleDefault,omitempty"`
	PKIRoleRSA      string        `json:"pkiRoleRSA,omitempty"`
	PKIRoleEC       string        `json:"pkiRoleEC,omitempty"`
	PKIRoleEd25519  string        `json:"pkiRoleEd25519,omitempty"`
	RoleID          string        `json:"roleID,omitempty"`
	SecretID        auth.SecretID `json:"secretID,omitempty"`
	AppRole         string        `json:"appRole,omitempty"`
	IsWrappingToken bool          `json:"isWrappingToken,omitempty"`
}
// VaultCAS implements a Certificate Authority Service using Hashicorp Vault.
//
// Instances are built by New, which logs the client in with AppRole before
// returning, and are used to sign, fetch the root of, and revoke certificates.
type VaultCAS struct {
	client      *vault.Client // Vault API client, logged in by New
	config      VaultOptions  // decoded options with defaults applied
	fingerprint string        // expected SHA-256 fingerprint of the root CA
}
// Certificate groups the certificates extracted from a Vault response:
// the self-signed CA certificate (root), the other CA certificates
// (intermediates, ordered so each one is signed by the next), and the
// non-CA certificate (leaf). root and leaf may be nil when absent.
type Certificate struct {
	root          *x509.Certificate
	intermediates []*x509.Certificate
	leaf          *x509.Certificate
}
// loadOptions decodes the raw vaultCAS JSON configuration, fills in defaults
// for the optional fields and validates the required AppRole credentials.
//
// Defaults: PKI "pki"; PKIRoleDefault "default"; PKIRoleRSA, PKIRoleEC and
// PKIRoleEd25519 fall back to PKIRoleDefault; AppRole "auth/approle".
//
// It returns an error when the JSON cannot be decoded, when roleID is empty,
// or when secretID defines none of FromEnv, FromFile or FromString.
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
	// Decode into a value, not a nil pointer: a JSON `null` config then yields
	// an empty struct (rejected by the roleID check below) instead of leaving
	// the pointer nil and panicking on the first field access.
	var vc VaultOptions
	if err := json.Unmarshal(config, &vc); err != nil {
		return nil, errors.Wrap(err, "error decoding vaultCAS config")
	}

	if vc.PKI == "" {
		vc.PKI = "pki" // use default pki vault name
	}
	if vc.PKIRoleDefault == "" {
		vc.PKIRoleDefault = "default" // use default pki role name
	}
	if vc.PKIRoleRSA == "" {
		vc.PKIRoleRSA = vc.PKIRoleDefault
	}
	if vc.PKIRoleEC == "" {
		vc.PKIRoleEC = vc.PKIRoleDefault
	}
	if vc.PKIRoleEd25519 == "" {
		vc.PKIRoleEd25519 = vc.PKIRoleDefault
	}
	if vc.AppRole == "" {
		vc.AppRole = "auth/approle"
	}

	if vc.RoleID == "" {
		return nil, errors.New("vaultCAS config options must define `roleID`")
	}
	if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
		return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
	}

	return &vc, nil
}
func certificateSort(n []*x509.Certificate) bool {
// sort all cert using bubble sort
isSorted := false
s := 0
maxSwap := len(n) * (len(n) - 1) / 2
for s <= maxSwap && !isSorted {
isSorted = true
var i = 0
for i < len(n)-1 {
if !isSignedBy(n[i], n[i+1]) {
// swap
n[i], n[i+1] = n[i+1], n[i]
isSorted = false
}
i++
}
s++
}
return isSorted
}
// isSignedBy reports whether certificate i verifies with j as the only trust
// anchor, i.e. whether j issued i.
//
// Only j is trusted: Roots holds just j and the intermediate pool is empty.
// KeyUsages is set to ExtKeyUsageAny because this helper only checks the
// issuing relationship. Without it, Verify defaults to requiring
// ExtKeyUsageServerAuth across the chain, so a CA restricted to another
// extended key usage (e.g. client authentication) would wrongly be reported
// as not signed by its issuer. Verify still checks validity periods against
// the current time.
func isSignedBy(i *x509.Certificate, j *x509.Certificate) bool {
	signer := x509.NewCertPool()
	signer.AddCert(j)
	opts := x509.VerifyOptions{
		Roots:         signer,
		Intermediates: x509.NewCertPool(), // set empty to avoid using system CA
		KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
	}
	_, err := i.Verify(opts)
	return err == nil
}
// parseCertificates decodes every PEM "CERTIFICATE" block in pemCert and
// returns the parsed certificates in input order (nil when none were found).
//
// Blocks of any other PEM type (for example a private key in a bundle) and
// CERTIFICATE blocks whose DER contents fail to parse are skipped, so one
// unexpected block no longer hides the certificates that follow it. Scanning
// ends when no further PEM block can be decoded.
func parseCertificates(pemCert string) []*x509.Certificate {
	var certs []*x509.Certificate
	rest := []byte(pemCert)
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			return certs
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			// Best effort: skip the malformed entry instead of abandoning
			// the rest of the input.
			continue
		}
		certs = append(certs, cert)
	}
}
// getCertificateAndChain parses every PEM certificate found in certb.CAChain
// and certb.Certificate and classifies them into a Certificate:
//   - a CA certificate whose issuer equals its subject becomes the root,
//   - a non-CA certificate becomes the leaf,
//   - every other CA certificate is kept as an intermediate.
//
// CAChain and Certificate may overlap or be disjoint depending on how the
// chain was configured in Vault, so exact duplicates are dropped and the
// intermediates are then ordered with certificateSort (each one signed by the
// next). If several roots or leaves are present the last one seen wins; root
// and leaf stay nil when no matching certificate exists.
//
// It returns an error when the intermediates cannot be ordered into a chain.
func getCertificateAndChain(certb certutil.CertBundle) (*Certificate, error) {
	var root, leaf *x509.Certificate
	intermediates := make([]*x509.Certificate, 0)

	// Deduplicate on the full DER encoding: serial numbers are only unique per
	// issuer, so two different CAs (e.g. a root and an intermediate both using
	// serial 1) must not be collapsed into one entry.
	seen := make(map[string]bool)

	// Build a fresh slice so appending the leaf PEM never writes into the
	// backing array of certb.CAChain, which is shared with the caller.
	pemChains := make([]string, 0, len(certb.CAChain)+1)
	pemChains = append(pemChains, certb.CAChain...)
	pemChains = append(pemChains, certb.Certificate)

	for _, chain := range pemChains {
		for _, cert := range parseCertificates(chain) {
			key := string(cert.Raw)
			if seen[key] {
				continue
			}
			seen[key] = true

			switch {
			case cert.IsCA && bytes.Equal(cert.RawIssuer, cert.RawSubject):
				root = cert
			case !cert.IsCA:
				leaf = cert
			default:
				intermediates = append(intermediates, cert)
			}
		}
	}

	if ok := certificateSort(intermediates); !ok {
		return nil, errors.Errorf("failed to sort certificate, probably one of cert is not part of the chain")
	}

	return &Certificate{
		root:          root,
		leaf:          leaf,
		intermediates: intermediates,
	}, nil
}
// parseCertificateRequest decodes the first PEM block of pemCsr and parses its
// contents as a certificate request. The PEM block type is not inspected.
//
// It fails when no PEM block can be decoded (the error message includes the
// input) or when the block's bytes are not a valid certificate request.
func parseCertificateRequest(pemCsr string) (*x509.CertificateRequest, error) {
	raw := []byte(pemCsr)
	block, _ := pem.Decode(raw)
	if block == nil {
		return nil, errors.Errorf("error decoding certificate request: not a valid PEM encoded block, please verify\r\n%v", pemCsr)
	}

	csr, parseErr := x509.ParseCertificateRequest(block.Bytes)
	if parseErr != nil {
		return nil, errors.Wrap(parseErr, "error parsing certificate request")
	}
	return csr, nil
}
// createCertificate asks Vault to sign cr and returns the issued leaf
// certificate together with the ordered intermediates from Vault's response.
//
// The PKI role is chosen from the CSR's public key algorithm (RSA, ECDSA or
// Ed25519); any other algorithm is rejected before Vault is contacted. The
// PEM-encoded CSR is written to "<PKI>/sign/<role>" with format "pem_bundle"
// and ttl set to lifetime in seconds.
//
// An error is returned if the request fails, Vault returns no secret, the
// response cannot be decoded, its chain cannot be ordered, or it contains no
// leaf (non-CA) certificate.
func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
	var vaultPKIRole string
	switch cr.PublicKeyAlgorithm {
	case x509.RSA:
		vaultPKIRole = v.config.PKIRoleRSA
	case x509.ECDSA:
		vaultPKIRole = v.config.PKIRoleEC
	case x509.Ed25519:
		vaultPKIRole = v.config.PKIRoleEd25519
	default:
		return nil, nil, errors.Errorf("createCertificate: Unsupported public key algorithm '%v'", cr.PublicKeyAlgorithm)
	}

	certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: cr.Raw})
	if certPemBytes == nil {
		return nil, nil, errors.Errorf("createCertificate: Failed to encode pem '%v'", cr.Raw)
	}

	y := map[string]interface{}{
		"csr":    string(certPemBytes),
		"format": "pem_bundle",
		"ttl":    lifetime.Seconds(),
	}
	secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, y)
	if err != nil {
		return nil, nil, errors.Wrapf(err, "createCertificate: unable to sign certificate %v", y)
	}
	if secret == nil {
		return nil, nil, errors.New("createCertificate: secret sign is empty")
	}

	// Round-trip the untyped secret data through JSON to fill a CertBundle.
	var certBundle certutil.CertBundle
	secretData, err := json.Marshal(secret.Data)
	if err != nil {
		return nil, nil, errors.Wrap(err, "createCertificate: error encoding vault response")
	}
	if err := json.Unmarshal(secretData, &certBundle); err != nil {
		return nil, nil, errors.Wrap(err, "createCertificate: error decoding vault response")
	}

	cert, err := getCertificateAndChain(certBundle)
	if err != nil {
		return nil, nil, err
	}
	// Without this check a response lacking a non-CA certificate would be
	// reported as success with a nil certificate.
	if cert.leaf == nil {
		return nil, nil, errors.New("createCertificate: no certificate found in vault sign response")
	}

	// Return certificate and certificate chain
	return cert.leaf, cert.intermediates, nil
}
// New creates a new CertificateAuthorityService implementation
// using Hashicorp Vault.
//
// opts.CertificateAuthority is used as the Vault address and, like
// opts.CertificateAuthorityFingerprint, must not be empty. opts.Config is
// decoded with loadOptions. The client then logs in with the AppRole auth
// method mounted at the configured AppRole path, optionally treating the
// secret ID as a wrapping token. The returned VaultCAS keeps that client.
func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
	switch {
	case opts.CertificateAuthority == "":
		return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty")
	case opts.CertificateAuthorityFingerprint == "":
		return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty")
	}

	vc, err := loadOptions(opts.Config)
	if err != nil {
		return nil, err
	}

	vaultConfig := vault.DefaultConfig()
	vaultConfig.Address = opts.CertificateAuthority
	client, err := vault.NewClient(vaultConfig)
	if err != nil {
		return nil, errors.Wrap(err, "unable to initialize vault client")
	}

	// When requested, the wrapping-token option is passed ahead of the mount
	// path option.
	var loginOpts []auth.LoginOption
	if vc.IsWrappingToken {
		loginOpts = append(loginOpts, auth.WithWrappingToken())
	}
	loginOpts = append(loginOpts, auth.WithMountPath(vc.AppRole))

	appRoleAuth, err := auth.NewAppRoleAuth(vc.RoleID, &vc.SecretID, loginOpts...)
	if err != nil {
		return nil, errors.Wrap(err, "unable to initialize AppRole auth method")
	}

	authInfo, err := client.Auth().Login(ctx, appRoleAuth)
	switch {
	case err != nil:
		return nil, errors.Wrap(err, "unable to login to AppRole auth method")
	case authInfo == nil:
		return nil, errors.New("no auth info was returned after login")
	}

	cas := &VaultCAS{
		client:      client,
		config:      *vc,
		fingerprint: opts.CertificateAuthorityFingerprint,
	}
	return cas, nil
}
// CreateCertificate signs a new certificate using Hashicorp Vault.
//
// req.CSR must be non-nil and req.Lifetime non-zero. The response carries the
// issued certificate and the intermediate chain returned by Vault.
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
	if req.CSR == nil {
		return nil, errors.New("CreateCertificate: `CSR` cannot be nil")
	}
	if req.Lifetime == 0 {
		return nil, errors.New("CreateCertificate: `LIFETIME` cannot be 0")
	}

	leaf, intermediates, err := v.createCertificate(req.CSR, req.Lifetime)
	if err != nil {
		return nil, err
	}

	resp := &apiv1.CreateCertificateResponse{
		Certificate:      leaf,
		CertificateChain: intermediates,
	}
	return resp, nil
}
// GetCertificateAuthority reads the CA chain from "<PKI>/cert/ca_chain",
// extracts its self-signed root and returns it after checking that the root's
// SHA-256 fingerprint (formatted with certutil.GetHexFormatted and no
// separator) equals the configured certificateAuthorityFingerprint.
//
// It returns an error when Vault cannot be read or returns no data, when the
// chain cannot be decoded or ordered, when the chain contains no self-signed
// root, or when the fingerprint does not match.
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
	secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain")
	if err != nil {
		return nil, errors.Wrap(err, "unable to read root")
	}
	if secret == nil {
		return nil, errors.New("secret root is empty")
	}

	// Round-trip the untyped secret data through JSON to fill a CertBundle.
	var certBundle certutil.CertBundle
	secretData, err := json.Marshal(secret.Data)
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(secretData, &certBundle); err != nil {
		return nil, err
	}

	cert, err := getCertificateAndChain(certBundle)
	if err != nil {
		return nil, err
	}
	// The chain may not include a self-signed certificate (e.g. a mount
	// configured with only its issuing CA); fail instead of dereferencing nil.
	if cert.root == nil {
		return nil, errors.New("no root certificate found in vault CA chain")
	}

	sha256Sum := sha256.Sum256(cert.root.Raw)
	expectedSum := certutil.GetHexFormatted(sha256Sum[:], "")
	if expectedSum != v.fingerprint {
		return nil, errors.Errorf("Vault Root CA fingerprint `%s` doesn't match config fingerprint `%v`", expectedSum, v.fingerprint)
	}

	return &apiv1.GetCertificateAuthorityResponse{
		RootCertificate: cert.root,
	}, nil
}
// RenewCertificate is not supported by the Vault CAS: it always returns an
// apiv1.ErrNotImplemented error and a nil response.
func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
	notImplemented := apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"}
	return nil, notImplemented
}
// RevokeCertificate revokes a certificate by writing to "<PKI>/revoke/".
//
// The serial number comes from req.SerialNumber (a base-10 integer string)
// when it is set, otherwise from req.Certificate.SerialNumber; one of the two
// is required. It is sent as "serial_number", formatted with
// certutil.GetHexFormatted using "-" as separator. The response echoes
// req.Certificate (nil when only a serial number was given) and no chain.
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
	var serialBytes []byte
	switch {
	case req.SerialNumber != "":
		// req.SerialNumber is a big.Int string representation
		n, ok := new(big.Int).SetString(req.SerialNumber, 10)
		if !ok {
			return nil, errors.Errorf("serialNumber `%v` can't be convert to big.Int", req.SerialNumber)
		}
		serialBytes = n.Bytes()
	case req.Certificate != nil:
		serialBytes = req.Certificate.SerialNumber.Bytes()
	default:
		return nil, errors.New("`serialNumber` or `certificate` are required")
	}

	payload := map[string]interface{}{
		"serial_number": certutil.GetHexFormatted(serialBytes, "-"),
	}
	if _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", payload); err != nil {
		return nil, errors.Wrap(err, "unable to revoke certificate")
	}

	return &apiv1.RevokeCertificateResponse{
		Certificate:      req.Certificate,
		CertificateChain: nil,
	}, nil
}