Cleanup Vault CAS integration
This commit is contained in:
parent
9134bad22c
commit
967d9136ca
2 changed files with 181 additions and 241 deletions
|
@ -1,12 +1,15 @@
|
||||||
package vaultcas
|
package vaultcas
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -14,7 +17,6 @@ import (
|
||||||
|
|
||||||
vault "github.com/hashicorp/vault/api"
|
vault "github.com/hashicorp/vault/api"
|
||||||
auth "github.com/hashicorp/vault/api/auth/approle"
|
auth "github.com/hashicorp/vault/api/auth/approle"
|
||||||
certutil "github.com/hashicorp/vault/sdk/helper/certutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -28,7 +30,7 @@ type VaultOptions struct {
|
||||||
PKIRoleDefault string `json:"PKIRoleDefault,omitempty"`
|
PKIRoleDefault string `json:"PKIRoleDefault,omitempty"`
|
||||||
PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
|
PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
|
||||||
PKIRoleEC string `json:"pkiRoleEC,omitempty"`
|
PKIRoleEC string `json:"pkiRoleEC,omitempty"`
|
||||||
PKIRoleEd25519 string `json:"PKIRoleEd25519,omitempty"`
|
PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
|
||||||
RoleID string `json:"roleID,omitempty"`
|
RoleID string `json:"roleID,omitempty"`
|
||||||
SecretID auth.SecretID `json:"secretID,omitempty"`
|
SecretID auth.SecretID `json:"secretID,omitempty"`
|
||||||
AppRole string `json:"appRole,omitempty"`
|
AppRole string `json:"appRole,omitempty"`
|
||||||
|
@ -42,207 +44,12 @@ type VaultCAS struct {
|
||||||
fingerprint string
|
fingerprint string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Certificate struct {
|
type certBundle struct {
|
||||||
leaf *x509.Certificate
|
leaf *x509.Certificate
|
||||||
intermediates []*x509.Certificate
|
intermediates []*x509.Certificate
|
||||||
root *x509.Certificate
|
root *x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
|
|
||||||
var vc *VaultOptions
|
|
||||||
|
|
||||||
err := json.Unmarshal(config, &vc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error decoding vaultCAS config")
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.PKI == "" {
|
|
||||||
vc.PKI = "pki" // use default pki vault name
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.PKIRoleDefault == "" {
|
|
||||||
vc.PKIRoleDefault = "default" // use default pki role name
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.PKIRoleRSA == "" {
|
|
||||||
vc.PKIRoleRSA = vc.PKIRoleDefault
|
|
||||||
}
|
|
||||||
if vc.PKIRoleEC == "" {
|
|
||||||
vc.PKIRoleEC = vc.PKIRoleDefault
|
|
||||||
}
|
|
||||||
if vc.PKIRoleEd25519 == "" {
|
|
||||||
vc.PKIRoleEd25519 = vc.PKIRoleDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.RoleID == "" {
|
|
||||||
return nil, errors.New("vaultCAS config options must define `roleID`")
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
|
|
||||||
return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.PKI == "" {
|
|
||||||
vc.PKI = "pki" // use default pki vault name
|
|
||||||
}
|
|
||||||
|
|
||||||
if vc.AppRole == "" {
|
|
||||||
vc.AppRole = "auth/approle"
|
|
||||||
}
|
|
||||||
|
|
||||||
return vc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func certificateSort(n []*x509.Certificate) bool {
|
|
||||||
// sort all cert using bubble sort
|
|
||||||
isSorted := false
|
|
||||||
s := 0
|
|
||||||
maxSwap := len(n) * (len(n) - 1) / 2
|
|
||||||
for s <= maxSwap && !isSorted {
|
|
||||||
isSorted = true
|
|
||||||
var i = 0
|
|
||||||
for i < len(n)-1 {
|
|
||||||
if !isSignedBy(n[i], n[i+1]) {
|
|
||||||
// swap
|
|
||||||
n[i], n[i+1] = n[i+1], n[i]
|
|
||||||
isSorted = false
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
s++
|
|
||||||
}
|
|
||||||
return isSorted
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSignedBy(i, j *x509.Certificate) bool {
|
|
||||||
signer := x509.NewCertPool()
|
|
||||||
signer.AddCert(j)
|
|
||||||
|
|
||||||
opts := x509.VerifyOptions{
|
|
||||||
Roots: signer,
|
|
||||||
Intermediates: x509.NewCertPool(), // set empty to avoid using system CA
|
|
||||||
}
|
|
||||||
_, err := i.Verify(opts)
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCertificates(pemCert string) []*x509.Certificate {
|
|
||||||
var certs []*x509.Certificate
|
|
||||||
rest := []byte(pemCert)
|
|
||||||
var block *pem.Block
|
|
||||||
for {
|
|
||||||
block, rest = pem.Decode(rest)
|
|
||||||
if block == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
certs = append(certs, cert)
|
|
||||||
}
|
|
||||||
return certs
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCertificateAndChain(certb certutil.CertBundle) (*Certificate, error) {
|
|
||||||
// certutil.CertBundle contains CAChain and Certificate.
|
|
||||||
// Both could have a common part or different and we are not sure
|
|
||||||
// how user define their chain inside vault.
|
|
||||||
// We will create an array of certificate with all parsed certificates
|
|
||||||
// then sort the array to create a consistent chain
|
|
||||||
var root *x509.Certificate
|
|
||||||
var leaf *x509.Certificate
|
|
||||||
intermediates := make([]*x509.Certificate, 0)
|
|
||||||
used := make(map[string]bool) // ensure that intermediate are uniq
|
|
||||||
for _, chain := range append(certb.CAChain, certb.Certificate) {
|
|
||||||
for _, cert := range parseCertificates(chain) {
|
|
||||||
if used[cert.SerialNumber.String()] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
used[cert.SerialNumber.String()] = true
|
|
||||||
switch {
|
|
||||||
case isRoot(cert):
|
|
||||||
root = cert
|
|
||||||
case cert.BasicConstraintsValid && cert.IsCA:
|
|
||||||
intermediates = append(intermediates, cert)
|
|
||||||
default:
|
|
||||||
leaf = cert
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ok := certificateSort(intermediates); !ok {
|
|
||||||
return nil, errors.Errorf("failed to sort certificate, probably one of cert is not part of the chain")
|
|
||||||
}
|
|
||||||
|
|
||||||
certificate := &Certificate{
|
|
||||||
root: root,
|
|
||||||
leaf: leaf,
|
|
||||||
intermediates: intermediates,
|
|
||||||
}
|
|
||||||
|
|
||||||
return certificate, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCertificateRequest(pemCsr string) (*x509.CertificateRequest, error) {
|
|
||||||
block, _ := pem.Decode([]byte(pemCsr))
|
|
||||||
if block == nil {
|
|
||||||
return nil, errors.Errorf("error decoding certificate request: not a valid PEM encoded block, please verify\r\n%v", pemCsr)
|
|
||||||
}
|
|
||||||
|
|
||||||
cr, err := x509.ParseCertificateRequest(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error parsing certificate request")
|
|
||||||
}
|
|
||||||
return cr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
|
|
||||||
var vaultPKIRole string
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case cr.PublicKeyAlgorithm == x509.RSA:
|
|
||||||
vaultPKIRole = v.config.PKIRoleRSA
|
|
||||||
case cr.PublicKeyAlgorithm == x509.ECDSA:
|
|
||||||
vaultPKIRole = v.config.PKIRoleEC
|
|
||||||
case cr.PublicKeyAlgorithm == x509.Ed25519:
|
|
||||||
vaultPKIRole = v.config.PKIRoleEd25519
|
|
||||||
default:
|
|
||||||
return nil, nil, errors.Errorf("createCertificate: Unsupported public key algorithm '%v'", cr.PublicKeyAlgorithm)
|
|
||||||
}
|
|
||||||
|
|
||||||
certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: cr.Raw})
|
|
||||||
if certPemBytes == nil {
|
|
||||||
return nil, nil, errors.Errorf("createCertificate: Failed to encode pem '%v'", cr.Raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
y := map[string]interface{}{
|
|
||||||
"csr": string(certPemBytes),
|
|
||||||
"format": "pem_bundle",
|
|
||||||
"ttl": lifetime.Seconds(),
|
|
||||||
}
|
|
||||||
|
|
||||||
secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, y)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrapf(err, "createCertificate: unable to sign certificate %v", y)
|
|
||||||
}
|
|
||||||
if secret == nil {
|
|
||||||
return nil, nil, errors.New("createCertificate: secret sign is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
var certBundle certutil.CertBundle
|
|
||||||
if err := unmarshalMap(secret.Data, &certBundle); err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "error unmarshaling cert bundle")
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := getCertificateAndChain(certBundle)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return certificate and certificate chain
|
|
||||||
return cert.leaf, cert.intermediates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new CertificateAuthorityService implementation
|
// New creates a new CertificateAuthorityService implementation
|
||||||
// using Hashicorp Vault
|
// using Hashicorp Vault
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
|
func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
|
||||||
|
@ -305,9 +112,9 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
|
||||||
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
|
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
|
||||||
switch {
|
switch {
|
||||||
case req.CSR == nil:
|
case req.CSR == nil:
|
||||||
return nil, errors.New("CreateCertificate: `CSR` cannot be nil")
|
return nil, errors.New("createCertificate `csr` cannot be nil")
|
||||||
case req.Lifetime == 0:
|
case req.Lifetime == 0:
|
||||||
return nil, errors.New("CreateCertificate: `LIFETIME` cannot be 0")
|
return nil, errors.New("createCertificate `lifetime` cannot be 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, chain, err := v.createCertificate(req.CSR, req.Lifetime)
|
cert, chain, err := v.createCertificate(req.CSR, req.Lifetime)
|
||||||
|
@ -324,26 +131,28 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv
|
||||||
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
|
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
|
||||||
secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain")
|
secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "unable to read root")
|
return nil, errors.Wrap(err, "error reading ca chain")
|
||||||
}
|
}
|
||||||
if secret == nil {
|
if secret == nil {
|
||||||
return nil, errors.New("secret root is empty")
|
return nil, errors.New("error reading ca chain: response is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
var certBundle certutil.CertBundle
|
chain, ok := secret.Data["certificate"].(string)
|
||||||
if err := unmarshalMap(secret.Data, &certBundle); err != nil {
|
if !ok {
|
||||||
return nil, errors.Wrap(err, "error unmarshaling cert bundle")
|
return nil, errors.New("error unmarshaling vault response: certificate not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := getCertificateAndChain(certBundle)
|
cert, err := getCertificateBundle(chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if cert.root == nil {
|
||||||
|
return nil, errors.New("error unmarshaling vault response: root certificate not found")
|
||||||
|
}
|
||||||
|
|
||||||
sha256Sum := sha256.Sum256(cert.root.Raw)
|
sum := sha256.Sum256(cert.root.Raw)
|
||||||
expectedSum := certutil.GetHexFormatted(sha256Sum[:], "")
|
if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) {
|
||||||
if expectedSum != v.fingerprint {
|
return nil, errors.New("error verifying vault root: fingerprint does not match")
|
||||||
return nil, errors.Errorf("Vault Root CA fingerprint `%s` doesn't match config fingerprint `%v`", expectedSum, v.fingerprint)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &apiv1.GetCertificateAuthorityResponse{
|
return &apiv1.GetCertificateAuthorityResponse{
|
||||||
|
@ -357,34 +166,28 @@ func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.
|
||||||
return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"}
|
return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RevokeCertificate revokes a certificate by serial number.
|
||||||
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
||||||
if req.SerialNumber == "" && req.Certificate == nil {
|
if req.SerialNumber == "" && req.Certificate == nil {
|
||||||
return nil, errors.New("`serialNumber` or `certificate` are required")
|
return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var serialNumber []byte
|
var sn *big.Int
|
||||||
if req.SerialNumber != "" {
|
if req.SerialNumber != "" {
|
||||||
// req.SerialNumber is a big.Int string representation
|
var ok bool
|
||||||
n := new(big.Int)
|
if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok {
|
||||||
n, ok := n.SetString(req.SerialNumber, 10)
|
return nil, errors.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber)
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("serialNumber `%v` can't be convert to big.Int", req.SerialNumber)
|
|
||||||
}
|
}
|
||||||
serialNumber = n.Bytes()
|
|
||||||
} else {
|
} else {
|
||||||
// req.Certificate.SerialNumber is a big.Int
|
sn = req.Certificate.SerialNumber
|
||||||
serialNumber = req.Certificate.SerialNumber.Bytes()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serialNumberDash := certutil.GetHexFormatted(serialNumber, "-")
|
vaultReq := map[string]interface{}{
|
||||||
|
"serial_number": formatSerialNumber(sn),
|
||||||
y := map[string]interface{}{
|
|
||||||
"serial_number": serialNumberDash,
|
|
||||||
}
|
}
|
||||||
|
_, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq)
|
||||||
_, err := v.client.Logical().Write(v.config.PKI+"/revoke/", y)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "unable to revoke certificate")
|
return nil, errors.Wrap(err, "error revoking certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &apiv1.RevokeCertificateResponse{
|
return &apiv1.RevokeCertificateResponse{
|
||||||
|
@ -393,13 +196,136 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalMap(m map[string]interface{}, v interface{}) error {
|
func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
|
||||||
b, err := json.Marshal(m)
|
var vaultPKIRole string
|
||||||
if err != nil {
|
|
||||||
return err
|
switch {
|
||||||
|
case cr.PublicKeyAlgorithm == x509.RSA:
|
||||||
|
vaultPKIRole = v.config.PKIRoleRSA
|
||||||
|
case cr.PublicKeyAlgorithm == x509.ECDSA:
|
||||||
|
vaultPKIRole = v.config.PKIRoleEC
|
||||||
|
case cr.PublicKeyAlgorithm == x509.Ed25519:
|
||||||
|
vaultPKIRole = v.config.PKIRoleEd25519
|
||||||
|
default:
|
||||||
|
return nil, nil, errors.Errorf("unsupported public key algorithm '%v'", cr.PublicKeyAlgorithm)
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.Unmarshal(b, v)
|
vaultReq := map[string]interface{}{
|
||||||
|
"csr": string(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE REQUEST",
|
||||||
|
Bytes: cr.Raw,
|
||||||
|
})),
|
||||||
|
"format": "pem_bundle",
|
||||||
|
"ttl": lifetime.Seconds(),
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error signing certificate")
|
||||||
|
}
|
||||||
|
if secret == nil {
|
||||||
|
return nil, nil, errors.New("error signing certificate: response is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
chain, ok := secret.Data["certificate"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("error unmarshaling vault response: certificate not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := getCertificateBundle(chain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return certificate and certificate chain
|
||||||
|
return cert.leaf, cert.intermediates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
|
||||||
|
var vc *VaultOptions
|
||||||
|
|
||||||
|
err := json.Unmarshal(config, &vc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error decoding vaultCAS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.PKI == "" {
|
||||||
|
vc.PKI = "pki" // use default pki vault name
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.PKIRoleDefault == "" {
|
||||||
|
vc.PKIRoleDefault = "default" // use default pki role name
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.PKIRoleRSA == "" {
|
||||||
|
vc.PKIRoleRSA = vc.PKIRoleDefault
|
||||||
|
}
|
||||||
|
if vc.PKIRoleEC == "" {
|
||||||
|
vc.PKIRoleEC = vc.PKIRoleDefault
|
||||||
|
}
|
||||||
|
if vc.PKIRoleEd25519 == "" {
|
||||||
|
vc.PKIRoleEd25519 = vc.PKIRoleDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.RoleID == "" {
|
||||||
|
return nil, errors.New("vaultCAS config options must define `roleID`")
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
|
||||||
|
return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.PKI == "" {
|
||||||
|
vc.PKI = "pki" // use default pki vault name
|
||||||
|
}
|
||||||
|
|
||||||
|
if vc.AppRole == "" {
|
||||||
|
vc.AppRole = "auth/approle"
|
||||||
|
}
|
||||||
|
|
||||||
|
return vc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCertificates(pemCert string) []*x509.Certificate {
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
rest := []byte(pemCert)
|
||||||
|
var block *pem.Block
|
||||||
|
for {
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
return certs
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCertificateBundle(chain string) (*certBundle, error) {
|
||||||
|
var root *x509.Certificate
|
||||||
|
var leaf *x509.Certificate
|
||||||
|
var intermediates []*x509.Certificate
|
||||||
|
for _, cert := range parseCertificates(chain) {
|
||||||
|
switch {
|
||||||
|
case isRoot(cert):
|
||||||
|
root = cert
|
||||||
|
case cert.BasicConstraintsValid && cert.IsCA:
|
||||||
|
intermediates = append(intermediates, cert)
|
||||||
|
default:
|
||||||
|
leaf = cert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
certificate := &certBundle{
|
||||||
|
root: root,
|
||||||
|
leaf: leaf,
|
||||||
|
intermediates: intermediates,
|
||||||
|
}
|
||||||
|
|
||||||
|
return certificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRoot returns true if the given certificate is a root certificate.
|
// isRoot returns true if the given certificate is a root certificate.
|
||||||
|
@ -409,3 +335,16 @@ func isRoot(cert *x509.Certificate) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// formatSerialNumber formats a serial number to a dash-separated hexadecimal
|
||||||
|
// string.
|
||||||
|
func formatSerialNumber(sn *big.Int) string {
|
||||||
|
var ret bytes.Buffer
|
||||||
|
for _, b := range sn.Bytes() {
|
||||||
|
if ret.Len() > 0 {
|
||||||
|
ret.WriteString("-")
|
||||||
|
}
|
||||||
|
ret.WriteString(hex.EncodeToString([]byte{b}))
|
||||||
|
}
|
||||||
|
return ret.String()
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
vault "github.com/hashicorp/vault/api"
|
vault "github.com/hashicorp/vault/api"
|
||||||
auth "github.com/hashicorp/vault/api/auth/approle"
|
auth "github.com/hashicorp/vault/api/auth/approle"
|
||||||
"github.com/smallstep/certificates/cas/apiv1"
|
"github.com/smallstep/certificates/cas/apiv1"
|
||||||
|
"go.step.sm/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -80,13 +81,13 @@ func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate {
|
||||||
return crt
|
return crt
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustParseCertificateRequest(t *testing.T, pemCert string) *x509.CertificateRequest {
|
func mustParseCertificateRequest(t *testing.T, pemData string) *x509.CertificateRequest {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
crt, err := parseCertificateRequest(pemCert)
|
csr, err := pemutil.ParseCertificateRequest([]byte(pemData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
return crt
|
return csr
|
||||||
}
|
}
|
||||||
|
|
||||||
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
||||||
|
@ -107,17 +108,17 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
||||||
}`)
|
}`)
|
||||||
case r.RequestURI == "/v1/pki/sign/ec":
|
case r.RequestURI == "/v1/pki/sign/ec":
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}}
|
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}}
|
||||||
writeJSON(w, cert)
|
writeJSON(w, cert)
|
||||||
return
|
return
|
||||||
case r.RequestURI == "/v1/pki/sign/rsa":
|
case r.RequestURI == "/v1/pki/sign/rsa":
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}}
|
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}}
|
||||||
writeJSON(w, cert)
|
writeJSON(w, cert)
|
||||||
return
|
return
|
||||||
case r.RequestURI == "/v1/pki/sign/ed25519":
|
case r.RequestURI == "/v1/pki/sign/ed25519":
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned}}
|
cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}}
|
||||||
writeJSON(w, cert)
|
writeJSON(w, cert)
|
||||||
return
|
return
|
||||||
case r.RequestURI == "/v1/pki/cert/ca_chain":
|
case r.RequestURI == "/v1/pki/cert/ca_chain":
|
||||||
|
@ -232,21 +233,21 @@ func TestVaultCAS_CreateCertificate(t *testing.T) {
|
||||||
Lifetime: time.Hour,
|
Lifetime: time.Hour,
|
||||||
}}, &apiv1.CreateCertificateResponse{
|
}}, &apiv1.CreateCertificateResponse{
|
||||||
Certificate: mustParseCertificate(t, testCertificateSigned),
|
Certificate: mustParseCertificate(t, testCertificateSigned),
|
||||||
CertificateChain: []*x509.Certificate{},
|
CertificateChain: nil,
|
||||||
}, false},
|
}, false},
|
||||||
{"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
{"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
||||||
CSR: mustParseCertificateRequest(t, testCertificateCsrRsa),
|
CSR: mustParseCertificateRequest(t, testCertificateCsrRsa),
|
||||||
Lifetime: time.Hour,
|
Lifetime: time.Hour,
|
||||||
}}, &apiv1.CreateCertificateResponse{
|
}}, &apiv1.CreateCertificateResponse{
|
||||||
Certificate: mustParseCertificate(t, testCertificateSigned),
|
Certificate: mustParseCertificate(t, testCertificateSigned),
|
||||||
CertificateChain: []*x509.Certificate{},
|
CertificateChain: nil,
|
||||||
}, false},
|
}, false},
|
||||||
{"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
{"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
||||||
CSR: mustParseCertificateRequest(t, testCertificateCsrEd25519),
|
CSR: mustParseCertificateRequest(t, testCertificateCsrEd25519),
|
||||||
Lifetime: time.Hour,
|
Lifetime: time.Hour,
|
||||||
}}, &apiv1.CreateCertificateResponse{
|
}}, &apiv1.CreateCertificateResponse{
|
||||||
Certificate: mustParseCertificate(t, testCertificateSigned),
|
Certificate: mustParseCertificate(t, testCertificateSigned),
|
||||||
CertificateChain: []*x509.Certificate{},
|
CertificateChain: nil,
|
||||||
}, false},
|
}, false},
|
||||||
{"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
{"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{
|
||||||
CSR: nil,
|
CSR: nil,
|
||||||
|
|
Loading…
Reference in a new issue