certificates/kms/azurekms/signer.go
2021-10-26 17:57:59 -07:00

182 lines
4.7 KiB
Go

package azurekms
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"encoding/base64"
"io"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// Signer implements a crypto.Signer using a key stored in Azure Key Vault.
//
// The private key never leaves the vault: signatures are produced by the
// Key Vault Sign operation, while the public key is fetched once and cached.
type Signer struct {
	// client is used for the Key Vault GetKey and Sign calls.
	client KeyVaultClient
	// vaultBaseURL is the base URL of the vault holding the key.
	vaultBaseURL string
	// name is the key name inside the vault.
	name string
	// version is the key version passed to Key Vault (as parsed from the
	// signing key identifier; may be empty — presumably meaning "latest",
	// TODO confirm against parseKeyName).
	version string
	// publicKey is the cached public part of the key, loaded by preloadKey.
	publicKey crypto.PublicKey
}
// NewSigner creates a new crypto.Signer backed by a key stored in Azure Key
// Vault.
//
// signingKey identifies the key; it is parsed by parseKeyName together with
// defaults to obtain the vault, key name and key version. The key is fetched
// once from the vault to verify that it exists and to cache its public key.
// An error is returned if the identifier cannot be parsed or the key cannot
// be retrieved or converted.
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
	vault, name, version, _, err := parseKeyName(signingKey, defaults)
	if err != nil {
		return nil, err
	}

	signer := &Signer{
		client:       client,
		vaultBaseURL: vaultBaseURL(vault),
		name:         name,
		version:      version,
	}

	// Make sure that the key exists and cache its public key.
	if err := signer.preloadKey(); err != nil {
		return nil, err
	}
	return signer, nil
}
// preloadKey fetches the key identified by s.name and s.version from the
// vault at s.vaultBaseURL and stores its public part in s.publicKey.
// NewSigner calls it once to confirm the key exists. s.publicKey is assigned
// whatever convertKey returns, even when convertKey reports an error.
func (s *Signer) preloadKey() error {
	ctx, cancel := defaultContext()
	defer cancel()

	bundle, getErr := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
	if getErr != nil {
		return errors.Wrap(getErr, "keyVault GetKey failed")
	}

	pub, convErr := convertKey(bundle.Key)
	s.publicKey = pub
	return convErr
}
// Public returns the public key of this signer, as cached when the key was
// loaded from Key Vault. It never returns an error; it is part of the
// crypto.Signer interface.
func (s *Signer) Public() crypto.PublicKey {
	return s.publicKey
}
// Sign signs digest with the private key stored in Azure Key Vault.
//
// The signature algorithm is derived from the public key type and
// opts.HashFunc(). The digest is sent base64url-encoded (without padding) to
// the Key Vault Sign operation. A request failing with "KeyNotYetValid" is
// retried up to 3 times.
//
// RSA signatures are returned exactly as Key Vault produced them. ECDSA
// signatures, which Key Vault returns as the fixed-size concatenation R || S,
// are re-encoded as an ASN.1 SEQUENCE of two INTEGERs, the form expected by
// ecdsa.VerifyASN1 and x509.
//
// rand is unused because the signature is computed remotely.
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
	alg, err := getSigningAlgorithm(s.Public(), opts)
	if err != nil {
		return nil, err
	}

	b64 := base64.RawURLEncoding.EncodeToString(digest)

	// Sign with retry if the key is not ready yet.
	resp, err := s.signWithRetry(alg, b64, 3)
	if err != nil {
		return nil, errors.Wrap(err, "keyVault Sign failed")
	}
	// A successful response without a result would otherwise panic on the
	// dereference below.
	if resp.Result == nil {
		return nil, errors.New("keyVault Sign failed: empty result")
	}

	sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
	if err != nil {
		return nil, errors.Wrap(err, "error decoding keyVault Sign result")
	}

	// Size in bytes of each of R and S for the ECDSA algorithms.
	var octetSize int
	switch alg {
	case keyvault.ES256:
		octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
	case keyvault.ES384:
		octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
	case keyvault.ES512:
		octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
	default:
		// RSA signatures need no conversion.
		return sig, nil
	}

	// Convert the raw R || S pair to ASN.1.
	if len(sig) != octetSize*2 {
		return nil, errors.New("keyVault Sign failed: unexpected signature length")
	}
	var b cryptobyte.Builder
	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
		b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
		b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
	})
	return b.Bytes()
}
// signWithRetry calls the Key Vault Sign operation for the signer's key with
// the given algorithm and base64url-encoded digest.
//
// If the call fails with an azure.RequestError whose inner error code is
// "KeyNotYetValid", it is retried up to retryAttempts more times. Before each
// retry it sleeps time.Second divided by the number of remaining attempts
// (e.g. ~333ms, 500ms, 1s for retryAttempts = 3). Any other error, or the
// error from the last allowed attempt, is returned unchanged.
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) {
	// signOnce performs a single Sign request with its own timeout context.
	// The context is cancelled as soon as the call returns, so retries do not
	// pile up deferred cancels (and live timers) until the loop finishes.
	signOnce := func() (keyvault.KeyOperationResult, error) {
		ctx, cancel := defaultContext()
		defer cancel()
		return s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
			Algorithm: alg,
			Value:     &b64,
		})
	}

	// isKeyNotYetValid reports whether err is a Key Vault service error whose
	// inner error code is "KeyNotYetValid".
	isKeyNotYetValid := func(err error) bool {
		var requestError *azure.RequestError
		if !errors.As(err, &requestError) {
			return false
		}
		se := requestError.ServiceError
		if se == nil || se.InnerError == nil {
			return false
		}
		code, ok := se.InnerError["code"].(string)
		return ok && code == "KeyNotYetValid"
	}

	for {
		resp, err := signOnce()
		if err == nil || retryAttempts <= 0 || !isKeyNotYetValid(err) {
			return resp, err
		}
		time.Sleep(time.Second / time.Duration(retryAttempts))
		retryAttempts--
	}
}
// getSigningAlgorithm maps a public key and signer options to the Azure Key
// Vault signature algorithm to request.
//
// RSA keys map to RS256/RS384/RS512 or, when opts is an *rsa.PSSOptions, to
// PS256/PS384/PS512. ECDSA keys map to ES256/ES384/ES512 based on the hash
// alone; the key's curve is not checked here.
//
// An error is returned for:
//   - any other key type;
//   - hash functions other than SHA-256, SHA-384 or SHA-512;
//   - an RSA-PSS salt length other than rsa.PSSSaltLengthAuto,
//     rsa.PSSSaltLengthEqualsHash or the hash size.
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
	// A nil opts is treated as "no hash" (crypto.Hash(0)), so it is reported
	// as an unsupported hash function instead of panicking.
	var hashFunc crypto.Hash
	if opts != nil {
		hashFunc = opts.HashFunc()
	}

	switch key.(type) {
	case *rsa.PublicKey:
		pss, isPSS := opts.(*rsa.PSSOptions)

		var alg keyvault.JSONWebKeySignatureAlgorithm
		switch hashFunc {
		case crypto.SHA256:
			alg = keyvault.RS256
			if isPSS {
				alg = keyvault.PS256
			}
		case crypto.SHA384:
			alg = keyvault.RS384
			if isPSS {
				alg = keyvault.PS384
			}
		case crypto.SHA512:
			alg = keyvault.RS512
			if isPSS {
				alg = keyvault.PS512
			}
		default:
			return "", errors.Errorf("unsupported hash function %v", hashFunc)
		}

		// Random salt lengths are not supported. The salt length is checked
		// only after the hash is known to be supported, because
		// crypto.Hash.Size panics for unknown hash values.
		if isPSS &&
			pss.SaltLength != rsa.PSSSaltLengthAuto &&
			pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
			pss.SaltLength != hashFunc.Size() {
			return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
		}
		return alg, nil
	case *ecdsa.PublicKey:
		switch hashFunc {
		case crypto.SHA256:
			return keyvault.ES256, nil
		case crypto.SHA384:
			return keyvault.ES384, nil
		case crypto.SHA512:
			return keyvault.ES512, nil
		default:
			return "", errors.Errorf("unsupported hash function %v", hashFunc)
		}
	default:
		return "", errors.Errorf("unsupported key type %T", key)
	}
}