182 lines
4.7 KiB
Go
182 lines
4.7 KiB
Go
package azurekms
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"io"
|
|
"math/big"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/crypto/cryptobyte"
|
|
"golang.org/x/crypto/cryptobyte/asn1"
|
|
)
|
|
|
|
// Signer implements a crypto.Signer using the AWS KMS.
|
|
type Signer struct {
|
|
client KeyVaultClient
|
|
vaultBaseURL string
|
|
name string
|
|
version string
|
|
publicKey crypto.PublicKey
|
|
}
|
|
|
|
// NewSigner creates a new signer using a key in the AWS KMS.
|
|
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
|
|
vault, name, version, _, err := parseKeyName(signingKey, defaults)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Make sure that the key exists.
|
|
signer := &Signer{
|
|
client: client,
|
|
vaultBaseURL: vaultBaseURL(vault),
|
|
name: name,
|
|
version: version,
|
|
}
|
|
if err := signer.preloadKey(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return signer, nil
|
|
}
|
|
|
|
func (s *Signer) preloadKey() error {
|
|
ctx, cancel := defaultContext()
|
|
defer cancel()
|
|
|
|
resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
|
|
if err != nil {
|
|
return errors.Wrap(err, "keyVault GetKey failed")
|
|
}
|
|
|
|
s.publicKey, err = convertKey(resp.Key)
|
|
return err
|
|
}
|
|
|
|
// Public returns the public key of this signer or an error.
|
|
func (s *Signer) Public() crypto.PublicKey {
|
|
return s.publicKey
|
|
}
|
|
|
|
// Sign signs digest with the private key stored in the AWS KMS.
|
|
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
|
alg, err := getSigningAlgorithm(s.Public(), opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
|
|
|
// Sign with retry if the key is not ready
|
|
resp, err := s.signWithRetry(alg, b64, 3)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "keyVault Sign failed")
|
|
}
|
|
|
|
sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "error decoding keyVault Sign result")
|
|
}
|
|
|
|
var octetSize int
|
|
switch alg {
|
|
case keyvault.ES256:
|
|
octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
|
|
case keyvault.ES384:
|
|
octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
|
|
case keyvault.ES512:
|
|
octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
|
|
default:
|
|
return sig, nil
|
|
}
|
|
|
|
// Convert to asn1
|
|
if len(sig) != octetSize*2 {
|
|
return nil, errors.Errorf("keyVault Sign failed: unexpected signature length")
|
|
}
|
|
var b cryptobyte.Builder
|
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
|
b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
|
|
b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
|
|
})
|
|
return b.Bytes()
|
|
}
|
|
|
|
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) {
|
|
retry:
|
|
ctx, cancel := defaultContext()
|
|
defer cancel()
|
|
|
|
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
|
Algorithm: alg,
|
|
Value: &b64,
|
|
})
|
|
if err != nil && retryAttempts > 0 {
|
|
var requestError *azure.RequestError
|
|
if errors.As(err, &requestError) {
|
|
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
|
|
code, ok := se.InnerError["code"].(string)
|
|
if ok && code == "KeyNotYetValid" {
|
|
time.Sleep(time.Second / time.Duration(retryAttempts))
|
|
retryAttempts--
|
|
goto retry
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return resp, err
|
|
}
|
|
|
|
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
|
switch key.(type) {
|
|
case *rsa.PublicKey:
|
|
hashFunc := opts.HashFunc()
|
|
pss, isPSS := opts.(*rsa.PSSOptions)
|
|
// Random salt lengths are not supported
|
|
if isPSS &&
|
|
pss.SaltLength != rsa.PSSSaltLengthAuto &&
|
|
pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
|
|
pss.SaltLength != hashFunc.Size() {
|
|
return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
|
|
}
|
|
|
|
switch h := hashFunc; h {
|
|
case crypto.SHA256:
|
|
if isPSS {
|
|
return keyvault.PS256, nil
|
|
}
|
|
return keyvault.RS256, nil
|
|
case crypto.SHA384:
|
|
if isPSS {
|
|
return keyvault.PS384, nil
|
|
}
|
|
return keyvault.RS384, nil
|
|
case crypto.SHA512:
|
|
if isPSS {
|
|
return keyvault.PS512, nil
|
|
}
|
|
return keyvault.RS512, nil
|
|
default:
|
|
return "", errors.Errorf("unsupported hash function %v", h)
|
|
}
|
|
case *ecdsa.PublicKey:
|
|
switch h := opts.HashFunc(); h {
|
|
case crypto.SHA256:
|
|
return keyvault.ES256, nil
|
|
case crypto.SHA384:
|
|
return keyvault.ES384, nil
|
|
case crypto.SHA512:
|
|
return keyvault.ES512, nil
|
|
default:
|
|
return "", errors.Errorf("unsupported hash function %v", h)
|
|
}
|
|
default:
|
|
return "", errors.Errorf("unsupported key type %T", key)
|
|
}
|
|
}
|