// Forked from TrueCloudLab/certificates (355 lines, 9.5 KiB, Go).

package vaultcas
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/smallstep/certificates/cas/apiv1"
|
|
|
|
vault "github.com/hashicorp/vault/api"
|
|
auth "github.com/hashicorp/vault/api/auth/approle"
|
|
)
|
|
|
|
func init() {
|
|
apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) {
|
|
return New(ctx, opts)
|
|
})
|
|
}
|
|
|
|
// VaultOptions defines the configuration options added using the
|
|
// apiv1.Options.Config field.
|
|
type VaultOptions struct {
|
|
PKI string `json:"pki,omitempty"`
|
|
PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
|
|
PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
|
|
PKIRoleEC string `json:"pkiRoleEC,omitempty"`
|
|
PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
|
|
RoleID string `json:"roleID,omitempty"`
|
|
SecretID auth.SecretID `json:"secretID,omitempty"`
|
|
AppRole string `json:"appRole,omitempty"`
|
|
IsWrappingToken bool `json:"isWrappingToken,omitempty"`
|
|
}
|
|
|
|
// VaultCAS implements a Certificate Authority Service using Hashicorp Vault.
|
|
type VaultCAS struct {
|
|
client *vault.Client
|
|
config VaultOptions
|
|
fingerprint string
|
|
}
|
|
|
|
// certBundle groups the certificates found in a PEM bundle by role, as
// classified by getCertificateBundle.
type certBundle struct {
	// leaf is the certificate that is not a CA, or nil if none was found.
	leaf *x509.Certificate

	// intermediates are the CA certificates that are not self-signed.
	intermediates []*x509.Certificate

	// root is the self-signed CA certificate, or nil if none was found.
	root *x509.Certificate
}
|
|
|
|
// New creates a new CertificateAuthorityService implementation
|
|
// using Hashicorp Vault
|
|
func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
|
|
if opts.CertificateAuthority == "" {
|
|
return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty")
|
|
}
|
|
|
|
if opts.CertificateAuthorityFingerprint == "" {
|
|
return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty")
|
|
}
|
|
|
|
vc, err := loadOptions(opts.Config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
config := vault.DefaultConfig()
|
|
config.Address = opts.CertificateAuthority
|
|
|
|
client, err := vault.NewClient(config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to initialize vault client: %w", err)
|
|
}
|
|
|
|
var appRoleAuth *auth.AppRoleAuth
|
|
if vc.IsWrappingToken {
|
|
appRoleAuth, err = auth.NewAppRoleAuth(
|
|
vc.RoleID,
|
|
&vc.SecretID,
|
|
auth.WithWrappingToken(),
|
|
auth.WithMountPath(vc.AppRole),
|
|
)
|
|
} else {
|
|
appRoleAuth, err = auth.NewAppRoleAuth(
|
|
vc.RoleID,
|
|
&vc.SecretID,
|
|
auth.WithMountPath(vc.AppRole),
|
|
)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err)
|
|
}
|
|
|
|
authInfo, err := client.Auth().Login(ctx, appRoleAuth)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err)
|
|
}
|
|
if authInfo == nil {
|
|
return nil, errors.New("no auth info was returned after login")
|
|
}
|
|
|
|
return &VaultCAS{
|
|
client: client,
|
|
config: *vc,
|
|
fingerprint: opts.CertificateAuthorityFingerprint,
|
|
}, nil
|
|
}
|
|
|
|
// CreateCertificate signs a new certificate using Hashicorp Vault.
|
|
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
|
|
switch {
|
|
case req.CSR == nil:
|
|
return nil, errors.New("createCertificate `csr` cannot be nil")
|
|
case req.Lifetime == 0:
|
|
return nil, errors.New("createCertificate `lifetime` cannot be 0")
|
|
}
|
|
|
|
cert, chain, err := v.createCertificate(req.CSR, req.Lifetime)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &apiv1.CreateCertificateResponse{
|
|
Certificate: cert,
|
|
CertificateChain: chain,
|
|
}, nil
|
|
}
|
|
|
|
// GetCertificateAuthority returns the root certificate of the certificate
|
|
// authority using the configured fingerprint.
|
|
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
|
|
secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading ca chain: %w", err)
|
|
}
|
|
if secret == nil {
|
|
return nil, errors.New("error reading ca chain: response is empty")
|
|
}
|
|
|
|
chain, ok := secret.Data["certificate"].(string)
|
|
if !ok {
|
|
return nil, errors.New("error unmarshaling vault response: certificate not found")
|
|
}
|
|
|
|
cert, err := getCertificateBundle(chain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cert.root == nil {
|
|
return nil, errors.New("error unmarshaling vault response: root certificate not found")
|
|
}
|
|
|
|
sum := sha256.Sum256(cert.root.Raw)
|
|
if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) {
|
|
return nil, errors.New("error verifying vault root: fingerprint does not match")
|
|
}
|
|
|
|
return &apiv1.GetCertificateAuthorityResponse{
|
|
RootCertificate: cert.root,
|
|
}, nil
|
|
}
|
|
|
|
// RenewCertificate will always return a non-implemented error as renewals
|
|
// are not supported yet.
|
|
func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
|
|
return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"}
|
|
}
|
|
|
|
// RevokeCertificate revokes a certificate by serial number.
|
|
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
|
if req.SerialNumber == "" && req.Certificate == nil {
|
|
return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required")
|
|
}
|
|
|
|
var sn *big.Int
|
|
if req.SerialNumber != "" {
|
|
var ok bool
|
|
if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok {
|
|
return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber)
|
|
}
|
|
} else {
|
|
sn = req.Certificate.SerialNumber
|
|
}
|
|
|
|
vaultReq := map[string]interface{}{
|
|
"serial_number": formatSerialNumber(sn),
|
|
}
|
|
_, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error revoking certificate: %w", err)
|
|
}
|
|
|
|
return &apiv1.RevokeCertificateResponse{
|
|
Certificate: req.Certificate,
|
|
CertificateChain: nil,
|
|
}, nil
|
|
}
|
|
|
|
func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
|
|
var vaultPKIRole string
|
|
|
|
switch {
|
|
case cr.PublicKeyAlgorithm == x509.RSA:
|
|
vaultPKIRole = v.config.PKIRoleRSA
|
|
case cr.PublicKeyAlgorithm == x509.ECDSA:
|
|
vaultPKIRole = v.config.PKIRoleEC
|
|
case cr.PublicKeyAlgorithm == x509.Ed25519:
|
|
vaultPKIRole = v.config.PKIRoleEd25519
|
|
default:
|
|
return nil, nil, fmt.Errorf("unsupported public key algorithm %v", cr.PublicKeyAlgorithm)
|
|
}
|
|
|
|
vaultReq := map[string]interface{}{
|
|
"csr": string(pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE REQUEST",
|
|
Bytes: cr.Raw,
|
|
})),
|
|
"format": "pem_bundle",
|
|
"ttl": lifetime.Seconds(),
|
|
}
|
|
|
|
secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error signing certificate: %w", err)
|
|
}
|
|
if secret == nil {
|
|
return nil, nil, errors.New("error signing certificate: response is empty")
|
|
}
|
|
|
|
chain, ok := secret.Data["certificate"].(string)
|
|
if !ok {
|
|
return nil, nil, errors.New("error unmarshaling vault response: certificate not found")
|
|
}
|
|
|
|
cert, err := getCertificateBundle(chain)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Return certificate and certificate chain
|
|
return cert.leaf, cert.intermediates, nil
|
|
}
|
|
|
|
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
|
|
var vc *VaultOptions
|
|
|
|
err := json.Unmarshal(config, &vc)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding vaultCAS config: %w", err)
|
|
}
|
|
|
|
if vc.PKI == "" {
|
|
vc.PKI = "pki" // use default pki vault name
|
|
}
|
|
|
|
if vc.PKIRoleDefault == "" {
|
|
vc.PKIRoleDefault = "default" // use default pki role name
|
|
}
|
|
|
|
if vc.PKIRoleRSA == "" {
|
|
vc.PKIRoleRSA = vc.PKIRoleDefault
|
|
}
|
|
if vc.PKIRoleEC == "" {
|
|
vc.PKIRoleEC = vc.PKIRoleDefault
|
|
}
|
|
if vc.PKIRoleEd25519 == "" {
|
|
vc.PKIRoleEd25519 = vc.PKIRoleDefault
|
|
}
|
|
|
|
if vc.RoleID == "" {
|
|
return nil, errors.New("vaultCAS config options must define `roleID`")
|
|
}
|
|
|
|
if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
|
|
return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
|
|
}
|
|
|
|
if vc.PKI == "" {
|
|
vc.PKI = "pki" // use default pki vault name
|
|
}
|
|
|
|
if vc.AppRole == "" {
|
|
vc.AppRole = "auth/approle"
|
|
}
|
|
|
|
return vc, nil
|
|
}
|
|
|
|
// parseCertificates decodes the PEM blocks in pemCert and returns the
// certificates they contain, in order of appearance.
//
// Blocks whose type is not "CERTIFICATE" are skipped so that other material
// in a bundle does not hide the certificates that follow it. Parsing stops at
// the first certificate block that cannot be parsed; the certificates decoded
// before it are still returned. Input without any PEM data yields a nil
// slice.
func parseCertificates(pemCert string) []*x509.Certificate {
	var certs []*x509.Certificate
	rest := []byte(pemCert)
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			return certs
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return certs
		}
		certs = append(certs, cert)
	}
}
|
|
|
|
func getCertificateBundle(chain string) (*certBundle, error) {
|
|
var root *x509.Certificate
|
|
var leaf *x509.Certificate
|
|
var intermediates []*x509.Certificate
|
|
for _, cert := range parseCertificates(chain) {
|
|
switch {
|
|
case isRoot(cert):
|
|
root = cert
|
|
case cert.BasicConstraintsValid && cert.IsCA:
|
|
intermediates = append(intermediates, cert)
|
|
default:
|
|
leaf = cert
|
|
}
|
|
}
|
|
|
|
certificate := &certBundle{
|
|
root: root,
|
|
leaf: leaf,
|
|
intermediates: intermediates,
|
|
}
|
|
|
|
return certificate, nil
|
|
}
|
|
|
|
// isRoot returns true if the given certificate is a root certificate, that
// is, a CA certificate (valid basic constraints with IsCA set) whose
// signature verifies against its own public key.
func isRoot(cert *x509.Certificate) bool {
	return cert.BasicConstraintsValid && cert.IsCA && cert.CheckSignatureFrom(cert) == nil
}
|
|
|
|
// formatSerialNumber formats a serial number to a dash-separated hexadecimal
// string.
//
// Each byte of the big-endian magnitude of sn is written as two lowercase
// hex digits, with "-" between bytes (for example 0x1234 becomes "12-34").
// A zero serial number has no bytes and yields an empty string.
func formatSerialNumber(sn *big.Int) string {
	// Hex-encode the whole magnitude once, then copy it out two digits
	// (one byte) at a time with a dash between groups.
	encoded := hex.EncodeToString(sn.Bytes())

	var ret bytes.Buffer
	for i := 0; i < len(encoded); i += 2 {
		if i > 0 {
			ret.WriteString("-")
		}
		ret.WriteString(encoded[i : i+2])
	}
	return ret.String()
}
|