Remove kms package

This commit is contained in:
Mariano Cano 2022-08-08 18:01:10 -07:00
parent 369b8f81c3
commit 4985ab1d62
57 changed files with 0 additions and 9040 deletions

View file

@ -1,137 +0,0 @@
package apiv1
import (
"crypto"
"crypto/x509"
"strings"
"github.com/pkg/errors"
)
// KeyManager is the interface implemented by all the KMS.
type KeyManager interface {
GetPublicKey(req *GetPublicKeyRequest) (crypto.PublicKey, error)
CreateKey(req *CreateKeyRequest) (*CreateKeyResponse, error)
CreateSigner(req *CreateSignerRequest) (crypto.Signer, error)
Close() error
}
// Decrypter is an interface implemented by KMSes that are used
// in operations that require decryption
type Decrypter interface {
CreateDecrypter(req *CreateDecrypterRequest) (crypto.Decrypter, error)
}
// CertificateManager is the interface implemented by the KMS that can load and
// store x509.Certificates.
type CertificateManager interface {
LoadCertificate(req *LoadCertificateRequest) (*x509.Certificate, error)
StoreCertificate(req *StoreCertificateRequest) error
}
// ValidateName is an interface that KeyManager can implement to validate a
// given name or URI.
type NameValidator interface {
ValidateName(s string) error
}
// ErrNotImplemented is the type of error returned if an operation is not
// implemented.
type ErrNotImplemented struct {
Message string
}
func (e ErrNotImplemented) Error() string {
if e.Message != "" {
return e.Message
}
return "not implemented"
}
// ErrAlreadyExists is the type of error returned if a key already exists. This
// is currently only implmented on pkcs11.
type ErrAlreadyExists struct {
Message string
}
func (e ErrAlreadyExists) Error() string {
if e.Message != "" {
return e.Message
}
return "key already exists"
}
// Type represents the KMS type used.
type Type string
const (
// DefaultKMS is a KMS implementation using software.
DefaultKMS Type = ""
// SoftKMS is a KMS implementation using software.
SoftKMS Type = "softkms"
// CloudKMS is a KMS implementation using Google's Cloud KMS.
CloudKMS Type = "cloudkms"
// AmazonKMS is a KMS implementation using Amazon AWS KMS.
AmazonKMS Type = "awskms"
// PKCS11 is a KMS implementation using the PKCS11 standard.
PKCS11 Type = "pkcs11"
// YubiKey is a KMS implementation using a YubiKey PIV.
YubiKey Type = "yubikey"
// SSHAgentKMS is a KMS implementation using ssh-agent to access keys.
SSHAgentKMS Type = "sshagentkms"
// AzureKMS is a KMS implementation using Azure Key Vault.
AzureKMS Type = "azurekms"
)
// Options are the KMS options. They represent the kms object in the ca.json.
type Options struct {
// The type of the KMS to use.
Type string `json:"type"`
// Path to the credentials file used in CloudKMS and AmazonKMS.
CredentialsFile string `json:"credentialsFile,omitempty"`
// URI is based on the PKCS #11 URI Scheme defined in
// https://tools.ietf.org/html/rfc7512 and represents the configuration used
// to connect to the KMS.
//
// Used by: pkcs11
URI string `json:"uri,omitempty"`
// Pin used to access the PKCS11 module. It can be defined in the URI using
// the pin-value or pin-source properties.
Pin string `json:"pin,omitempty"`
// ManagementKey used in YubiKeys. Default management key is the hexadecimal
// string 010203040506070801020304050607080102030405060708:
// []byte{
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
// }
ManagementKey string `json:"managementKey,omitempty"`
// Region to use in AmazonKMS.
Region string `json:"region,omitempty"`
// Profile to use in AmazonKMS.
Profile string `json:"profile,omitempty"`
}
// Validate checks the fields in Options.
func (o *Options) Validate() error {
if o == nil {
return nil
}
switch Type(strings.ToLower(o.Type)) {
case DefaultKMS, SoftKMS: // Go crypto based kms.
case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms.
case YubiKey, PKCS11: // Hardware based kms.
case SSHAgentKMS: // Others
default:
return errors.Errorf("unsupported kms type %s", o.Type)
}
return nil
}

View file

@ -1,76 +0,0 @@
package apiv1
import (
"testing"
)
func TestOptions_Validate(t *testing.T) {
tests := []struct {
name string
options *Options
wantErr bool
}{
{"nil", nil, false},
{"softkms", &Options{Type: "softkms"}, false},
{"cloudkms", &Options{Type: "cloudkms"}, false},
{"awskms", &Options{Type: "awskms"}, false},
{"sshagentkms", &Options{Type: "sshagentkms"}, false},
{"pkcs11", &Options{Type: "pkcs11"}, false},
{"unsupported", &Options{Type: "unsupported"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.options.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestErrNotImplemented_Error(t *testing.T) {
type fields struct {
msg string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{}, "not implemented"},
{"custom", fields{"custom message: not implemented"}, "custom message: not implemented"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ErrNotImplemented{
Message: tt.fields.msg,
}
if got := e.Error(); got != tt.want {
t.Errorf("ErrNotImplemented.Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestErrAlreadyExists_Error(t *testing.T) {
type fields struct {
msg string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{}, "key already exists"},
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ErrAlreadyExists{
Message: tt.fields.msg,
}
if got := e.Error(); got != tt.want {
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,27 +0,0 @@
package apiv1
import (
"context"
"sync"
)
var registry = new(sync.Map)
// KeyManagerNewFunc is the type that represents the method to initialize a new
// KeyManager.
type KeyManagerNewFunc func(ctx context.Context, opts Options) (KeyManager, error)
// Register adds to the registry a method to create a KeyManager of type t.
func Register(t Type, fn KeyManagerNewFunc) {
registry.Store(t, fn)
}
// LoadKeyManagerNewFunc returns the function initialize a KayManager.
func LoadKeyManagerNewFunc(t Type) (KeyManagerNewFunc, bool) {
v, ok := registry.Load(t)
if !ok {
return nil, false
}
fn, ok := v.(KeyManagerNewFunc)
return fn, ok
}

View file

@ -1,167 +0,0 @@
package apiv1
import (
"crypto"
"crypto/x509"
"fmt"
)
// ProtectionLevel specifies on some KMS how cryptographic operations are
// performed.
type ProtectionLevel int
const (
// Protection level not specified.
UnspecifiedProtectionLevel ProtectionLevel = iota
// Crypto operations are performed in software.
Software
// Crypto operations are performed in a Hardware Security Module.
HSM
)
// String returns a string representation of p.
func (p ProtectionLevel) String() string {
switch p {
case UnspecifiedProtectionLevel:
return "unspecified"
case Software:
return "software"
case HSM:
return "hsm"
default:
return fmt.Sprintf("unknown(%d)", p)
}
}
// SignatureAlgorithm used for cryptographic signing.
type SignatureAlgorithm int
const (
// Not specified.
UnspecifiedSignAlgorithm SignatureAlgorithm = iota
// RSASSA-PKCS1-v1_5 key and a SHA256 digest.
SHA256WithRSA
// RSASSA-PKCS1-v1_5 key and a SHA384 digest.
SHA384WithRSA
// RSASSA-PKCS1-v1_5 key and a SHA512 digest.
SHA512WithRSA
// RSASSA-PSS key with a SHA256 digest.
SHA256WithRSAPSS
// RSASSA-PSS key with a SHA384 digest.
SHA384WithRSAPSS
// RSASSA-PSS key with a SHA512 digest.
SHA512WithRSAPSS
// ECDSA on the NIST P-256 curve with a SHA256 digest.
ECDSAWithSHA256
// ECDSA on the NIST P-384 curve with a SHA384 digest.
ECDSAWithSHA384
// ECDSA on the NIST P-521 curve with a SHA512 digest.
ECDSAWithSHA512
// EdDSA on Curve25519 with a SHA512 digest.
PureEd25519
)
// String returns a string representation of s.
func (s SignatureAlgorithm) String() string {
switch s {
case UnspecifiedSignAlgorithm:
return "unspecified"
case SHA256WithRSA:
return "SHA256-RSA"
case SHA384WithRSA:
return "SHA384-RSA"
case SHA512WithRSA:
return "SHA512-RSA"
case SHA256WithRSAPSS:
return "SHA256-RSAPSS"
case SHA384WithRSAPSS:
return "SHA384-RSAPSS"
case SHA512WithRSAPSS:
return "SHA512-RSAPSS"
case ECDSAWithSHA256:
return "ECDSA-SHA256"
case ECDSAWithSHA384:
return "ECDSA-SHA384"
case ECDSAWithSHA512:
return "ECDSA-SHA512"
case PureEd25519:
return "Ed25519"
default:
return fmt.Sprintf("unknown(%d)", s)
}
}
// GetPublicKeyRequest is the parameter used in the kms.GetPublicKey method.
type GetPublicKeyRequest struct {
Name string
}
// CreateKeyRequest is the parameter used in the kms.CreateKey method.
type CreateKeyRequest struct {
// Name represents the key name or label used to identify a key.
//
// Used by: awskms, cloudkms, azurekms, pkcs11, yubikey.
Name string
// SignatureAlgorithm represents the type of key to create.
SignatureAlgorithm SignatureAlgorithm
// Bits is the number of bits on RSA keys.
Bits int
// ProtectionLevel specifies how cryptographic operations are performed.
// Used by: cloudkms, azurekms.
ProtectionLevel ProtectionLevel
// Extractable defines if the new key may be exported from the HSM under a
// wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
//
// Used by: pkcs11
Extractable bool
}
// CreateKeyResponse is the response value of the kms.CreateKey method.
type CreateKeyResponse struct {
Name string
PublicKey crypto.PublicKey
PrivateKey crypto.PrivateKey
CreateSignerRequest CreateSignerRequest
}
// CreateSignerRequest is the parameter used in the kms.CreateSigner method.
type CreateSignerRequest struct {
Signer crypto.Signer
SigningKey string
SigningKeyPEM []byte
TokenLabel string
PublicKey string
PublicKeyPEM []byte
Password []byte
}
// CreateDecrypterRequest is the parameter used in the kms.Decrypt method.
type CreateDecrypterRequest struct {
Decrypter crypto.Decrypter
DecryptionKey string
DecryptionKeyPEM []byte
Password []byte
}
// LoadCertificateRequest is the parameter used in the LoadCertificate method of
// a CertificateManager.
type LoadCertificateRequest struct {
Name string
}
// StoreCertificateRequest is the parameter used in the StoreCertificate method
// of a CertificateManager.
type StoreCertificateRequest struct {
Name string
Certificate *x509.Certificate
// Extractable defines if the new certificate may be exported from the HSM
// under a wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
//
// Used by: pkcs11
Extractable bool
}

View file

@ -1,51 +0,0 @@
package apiv1
import "testing"
func TestProtectionLevel_String(t *testing.T) {
tests := []struct {
name string
p ProtectionLevel
want string
}{
{"unspecified", UnspecifiedProtectionLevel, "unspecified"},
{"software", Software, "software"},
{"hsm", HSM, "hsm"},
{"unknown", ProtectionLevel(100), "unknown(100)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.p.String(); got != tt.want {
t.Errorf("ProtectionLevel.String() = %v, want %v", got, tt.want)
}
})
}
}
func TestSignatureAlgorithm_String(t *testing.T) {
tests := []struct {
name string
s SignatureAlgorithm
want string
}{
{"UnspecifiedSignAlgorithm", UnspecifiedSignAlgorithm, "unspecified"},
{"SHA256WithRSA", SHA256WithRSA, "SHA256-RSA"},
{"SHA384WithRSA", SHA384WithRSA, "SHA384-RSA"},
{"SHA512WithRSA", SHA512WithRSA, "SHA512-RSA"},
{"SHA256WithRSAPSS", SHA256WithRSAPSS, "SHA256-RSAPSS"},
{"SHA384WithRSAPSS", SHA384WithRSAPSS, "SHA384-RSAPSS"},
{"SHA512WithRSAPSS", SHA512WithRSAPSS, "SHA512-RSAPSS"},
{"ECDSAWithSHA256", ECDSAWithSHA256, "ECDSA-SHA256"},
{"ECDSAWithSHA384", ECDSAWithSHA384, "ECDSA-SHA384"},
{"ECDSAWithSHA512", ECDSAWithSHA512, "ECDSA-SHA512"},
{"PureEd25519", PureEd25519, "Ed25519"},
{"unknown", SignatureAlgorithm(100), "unknown(100)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.s.String(); got != tt.want {
t.Errorf("SignatureAlgorithm.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,267 +0,0 @@
package awskms
import (
"context"
"crypto"
"net/url"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
"go.step.sm/crypto/pemutil"
)
// Scheme is the scheme used in uris.
const Scheme = "awskms"
// KMS implements a KMS using AWS Key Management Service.
type KMS struct {
session *session.Session
service KeyManagementClient
}
// KeyManagementClient defines the methods on KeyManagementClient that this
// package will use. This interface will be used for unit testing.
type KeyManagementClient interface {
GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
}
// customerMasterKeySpecMapping is a mapping between the step signature algorithm,
// and bits for RSA keys, with awskms CustomerMasterKeySpec.
var customerMasterKeySpecMapping = map[apiv1.SignatureAlgorithm]interface{}{
apiv1.UnspecifiedSignAlgorithm: kms.CustomerMasterKeySpecEccNistP256,
apiv1.SHA256WithRSA: map[int]string{
0: kms.CustomerMasterKeySpecRsa3072,
2048: kms.CustomerMasterKeySpecRsa2048,
3072: kms.CustomerMasterKeySpecRsa3072,
4096: kms.CustomerMasterKeySpecRsa4096,
},
apiv1.SHA512WithRSA: map[int]string{
0: kms.CustomerMasterKeySpecRsa4096,
4096: kms.CustomerMasterKeySpecRsa4096,
},
apiv1.SHA256WithRSAPSS: map[int]string{
0: kms.CustomerMasterKeySpecRsa3072,
2048: kms.CustomerMasterKeySpecRsa2048,
3072: kms.CustomerMasterKeySpecRsa3072,
4096: kms.CustomerMasterKeySpecRsa4096,
},
apiv1.SHA512WithRSAPSS: map[int]string{
0: kms.CustomerMasterKeySpecRsa4096,
4096: kms.CustomerMasterKeySpecRsa4096,
},
apiv1.ECDSAWithSHA256: kms.CustomerMasterKeySpecEccNistP256,
apiv1.ECDSAWithSHA384: kms.CustomerMasterKeySpecEccNistP384,
apiv1.ECDSAWithSHA512: kms.CustomerMasterKeySpecEccNistP521,
}
// New creates a new AWSKMS. By default, sessions will be created using the
// credentials in `~/.aws/credentials`, but this can be overridden using the
// CredentialsFile option, the Region and Profile can also be configured as
// options.
//
// AWS sessions can also be configured with environment variables, see docs at
// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for all the options.
func New(ctx context.Context, opts apiv1.Options) (*KMS, error) {
var o session.Options
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
o.Profile = u.Get("profile")
if v := u.Get("region"); v != "" {
o.Config.Region = new(string)
*o.Config.Region = v
}
if f := u.Get("credentials-file"); f != "" {
o.SharedConfigFiles = []string{f}
}
}
// Deprecated way to set configuration parameters.
if opts.Region != "" {
o.Config.Region = &opts.Region
}
if opts.Profile != "" {
o.Profile = opts.Profile
}
if opts.CredentialsFile != "" {
o.SharedConfigFiles = []string{opts.CredentialsFile}
}
sess, err := session.NewSessionWithOptions(o)
if err != nil {
return nil, errors.Wrap(err, "error creating AWS session")
}
return &KMS{
session: sess,
service: kms.New(sess),
}, nil
}
func init() {
apiv1.Register(apiv1.AmazonKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// GetPublicKey returns a public key from KMS.
func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
if req.Name == "" {
return nil, errors.New("getPublicKey 'name' cannot be empty")
}
keyID, err := parseKeyID(req.Name)
if err != nil {
return nil, err
}
ctx, cancel := defaultContext()
defer cancel()
resp, err := k.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
KeyId: &keyID,
})
if err != nil {
return nil, errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
}
return pemutil.ParseDER(resp.PublicKey)
}
// CreateKey generates a new key in KMS and returns the public key version
// of it.
func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
if req.Name == "" {
return nil, errors.New("createKeyRequest 'name' cannot be empty")
}
keySpec, err := getCustomerMasterKeySpecMapping(req.SignatureAlgorithm, req.Bits)
if err != nil {
return nil, err
}
tag := new(kms.Tag)
tag.SetTagKey("name")
tag.SetTagValue(req.Name)
input := &kms.CreateKeyInput{
Description: &req.Name,
CustomerMasterKeySpec: &keySpec,
Tags: []*kms.Tag{tag},
}
input.SetKeyUsage(kms.KeyUsageTypeSignVerify)
ctx, cancel := defaultContext()
defer cancel()
resp, err := k.service.CreateKeyWithContext(ctx, input)
if err != nil {
return nil, errors.Wrap(err, "awskms CreateKeyWithContext failed")
}
if err := k.createKeyAlias(*resp.KeyMetadata.KeyId, req.Name); err != nil {
return nil, err
}
// Create uri for key
name := uri.New("awskms", url.Values{
"key-id": []string{*resp.KeyMetadata.KeyId},
}).String()
publicKey, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Name: name,
})
if err != nil {
return nil, err
}
// Names uses Amazon Resource Name
// https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
return &apiv1.CreateKeyResponse{
Name: name,
PublicKey: publicKey,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: name,
},
}, nil
}
func (k *KMS) createKeyAlias(keyID, alias string) error {
alias = "alias/" + alias + "-" + keyID[:8]
ctx, cancel := defaultContext()
defer cancel()
_, err := k.service.CreateAliasWithContext(ctx, &kms.CreateAliasInput{
AliasName: &alias,
TargetKeyId: &keyID,
})
if err != nil {
return errors.Wrap(err, "awskms CreateAliasWithContext failed")
}
return nil
}
// CreateSigner creates a new crypto.Signer with a previously configured key.
func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.SigningKey == "" {
return nil, errors.New("createSigner 'signingKey' cannot be empty")
}
return NewSigner(k.service, req.SigningKey)
}
// Close closes the connection of the KMS client.
func (k *KMS) Close() error {
return nil
}
func defaultContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 15*time.Second)
}
// parseKeyID extracts the key-id from an uri.
func parseKeyID(name string) (string, error) {
name = strings.ToLower(name)
if strings.HasPrefix(name, "awskms:") || strings.HasPrefix(name, "aws:") {
u, err := uri.Parse(name)
if err != nil {
return "", err
}
if k := u.Get("key-id"); k != "" {
return k, nil
}
return "", errors.Errorf("failed to get key-id from %s", name)
}
return name, nil
}
func getCustomerMasterKeySpecMapping(alg apiv1.SignatureAlgorithm, bits int) (string, error) {
v, ok := customerMasterKeySpecMapping[alg]
if !ok {
return "", errors.Errorf("awskms does not support signature algorithm '%s'", alg)
}
switch v := v.(type) {
case string:
return v, nil
case map[int]string:
s, ok := v[bits]
if !ok {
return "", errors.Errorf("awskms does not support signature algorithm '%s' with '%d' bits", alg, bits)
}
return s, nil
default:
return "", errors.Errorf("unexpected error: this should not happen")
}
}

View file

@ -1,364 +0,0 @@
package awskms
import (
"context"
"crypto"
"fmt"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/pemutil"
)
func TestNew(t *testing.T) {
ctx := context.Background()
sess, err := session.NewSessionWithOptions(session.Options{})
if err != nil {
t.Fatal(err)
}
expected := &KMS{
session: sess,
service: kms.New(sess),
}
// This will force an error in the session creation.
// It does not fail with missing credentials.
forceError := func(t *testing.T) {
key := "AWS_CA_BUNDLE"
value := os.Getenv(key)
os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt"))
t.Cleanup(func() {
if value == "" {
os.Unsetenv(key)
} else {
os.Setenv(key, value)
}
})
}
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
want *KMS
wantErr bool
}{
{"ok", args{ctx, apiv1.Options{}}, expected, false},
{"ok with options", args{ctx, apiv1.Options{
Region: "us-east-1",
Profile: "smallstep",
CredentialsFile: "~/aws/credentials",
}}, expected, false},
{"ok with uri", args{ctx, apiv1.Options{
URI: "awskms:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
}}, expected, false},
{"fail", args{ctx, apiv1.Options{}}, nil, true},
{"fail uri", args{ctx, apiv1.Options{
URI: "pkcs11:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Force an error in the session loading
if tt.wantErr {
forceError(t)
}
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %#v, want %#v", got, tt.want)
}
} else {
if got.session == nil || got.service == nil {
t.Errorf("New() = %#v, want %#v", got, tt.want)
}
}
})
}
}
func TestKMS_GetPublicKey(t *testing.T) {
okClient := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type fields struct {
session *session.Session
service KeyManagementClient
}
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.PublicKey
wantErr bool
}{
{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
}}, key, false},
{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
Name: "awskms:key-id=",
}}, nil, true},
{"fail getPublicKey", fields{nil, &MockClient{
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return nil, fmt.Errorf("an error")
},
}}, args{&apiv1.GetPublicKeyRequest{
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
}}, nil, true},
{"fail not der", fields{nil, &MockClient{
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return &kms.GetPublicKeyOutput{
KeyId: input.KeyId,
PublicKey: []byte(publicKey),
}, nil
},
}}, args{&apiv1.GetPublicKeyRequest{
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KMS{
session: tt.fields.session,
service: tt.fields.service,
}
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestKMS_CreateKey(t *testing.T) {
okClient := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type fields struct {
session *session.Session
service KeyManagementClient
}
type args struct {
req *apiv1.CreateKeyRequest
}
tests := []struct {
name string
fields fields
args args
want *apiv1.CreateKeyResponse
wantErr bool
}{
{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, &apiv1.CreateKeyResponse{
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
PublicKey: key,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
},
}, false},
{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.SHA256WithRSA,
Bits: 2048,
}}, &apiv1.CreateKeyResponse{
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
PublicKey: key,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
},
}, false},
{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.PureEd25519,
}}, nil, true},
{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.SHA256WithRSA,
Bits: 1234,
}}, nil, true},
{"fail createKey", fields{nil, &MockClient{
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
return nil, fmt.Errorf("an error")
},
createAliasWithContext: okClient.createAliasWithContext,
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
}}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
{"fail createAlias", fields{nil, &MockClient{
createKeyWithContext: okClient.createKeyWithContext,
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
return nil, fmt.Errorf("an error")
},
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
}}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
{"fail getPublicKey", fields{nil, &MockClient{
createKeyWithContext: okClient.createKeyWithContext,
createAliasWithContext: okClient.createAliasWithContext,
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return nil, fmt.Errorf("an error")
},
}}, args{&apiv1.CreateKeyRequest{
Name: "root",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KMS{
session: tt.fields.session,
service: tt.fields.service,
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestKMS_CreateSigner(t *testing.T) {
client := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type fields struct {
session *session.Session
service KeyManagementClient
}
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.Signer
wantErr bool
}{
{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
}}, &Signer{
service: client,
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
publicKey: key,
}, false},
{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
{"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KMS{
session: tt.fields.session,
service: tt.fields.service,
}
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
}
})
}
}
func TestKMS_Close(t *testing.T) {
type fields struct {
session *session.Session
service KeyManagementClient
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{nil, getOKClient()}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KMS{
session: tt.fields.session,
service: tt.fields.service,
}
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_parseKeyID(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
{"fail empty key", args{"awskms:key-id="}, "", true},
{"fail missing", args{"awskms:foo=bar"}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseKeyID(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,72 +0,0 @@
package awskms
import (
"encoding/pem"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kms"
)
type MockClient struct {
getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
}
func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return m.getPublicKeyWithContext(ctx, input, opts...)
}
func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
return m.createKeyWithContext(ctx, input, opts...)
}
func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
return m.createAliasWithContext(ctx, input, opts...)
}
func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
return m.signWithContext(ctx, input, opts...)
}
const (
publicKey = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE8XWlIWkOThxNjGbZLYUgRHmsvCrW
KF+HLktPfPTIK3lGd1k4849WQs59XIN+LXZQ6b2eRBEBKAHEyQus8UU7gw==
-----END PUBLIC KEY-----`
keyID = "be468355-ca7a-40d9-a28b-8ae1c4c7f936"
)
var signature = []byte{
0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
}
func getOKClient() *MockClient {
return &MockClient{
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
block, _ := pem.Decode([]byte(publicKey))
return &kms.GetPublicKeyOutput{
KeyId: input.KeyId,
PublicKey: block.Bytes,
}, nil
},
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
md := new(kms.KeyMetadata)
md.SetKeyId(keyID)
return &kms.CreateKeyOutput{
KeyMetadata: md,
}, nil
},
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
return &kms.CreateAliasOutput{}, nil
},
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
return &kms.SignOutput{
Signature: signature,
}, nil
},
}
}

View file

@ -1,122 +0,0 @@
package awskms
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"io"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/pkg/errors"
"go.step.sm/crypto/pemutil"
)
// Signer implements a crypto.Signer using the AWS KMS.
type Signer struct {
service KeyManagementClient
keyID string
publicKey crypto.PublicKey
}
// NewSigner creates a new signer using a key in the AWS KMS.
func NewSigner(svc KeyManagementClient, signingKey string) (*Signer, error) {
keyID, err := parseKeyID(signingKey)
if err != nil {
return nil, err
}
// Make sure that the key exists.
signer := &Signer{
service: svc,
keyID: keyID,
}
if err := signer.preloadKey(keyID); err != nil {
return nil, err
}
return signer, nil
}
func (s *Signer) preloadKey(keyID string) error {
ctx, cancel := defaultContext()
defer cancel()
resp, err := s.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
KeyId: &keyID,
})
if err != nil {
return errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
}
s.publicKey, err = pemutil.ParseDER(resp.PublicKey)
return err
}
// Public returns the public key of this signer or an error.
func (s *Signer) Public() crypto.PublicKey {
return s.publicKey
}
// Sign signs digest with the private key stored in the AWS KMS.
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
alg, err := getSigningAlgorithm(s.Public(), opts)
if err != nil {
return nil, err
}
req := &kms.SignInput{
KeyId: &s.keyID,
SigningAlgorithm: &alg,
Message: digest,
}
req.SetMessageType("DIGEST")
ctx, cancel := defaultContext()
defer cancel()
resp, err := s.service.SignWithContext(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "awsKMS SignWithContext failed")
}
return resp.Signature, nil
}
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string, error) {
switch key.(type) {
case *rsa.PublicKey:
_, isPSS := opts.(*rsa.PSSOptions)
switch h := opts.HashFunc(); h {
case crypto.SHA256:
if isPSS {
return kms.SigningAlgorithmSpecRsassaPssSha256, nil
}
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
case crypto.SHA384:
if isPSS {
return kms.SigningAlgorithmSpecRsassaPssSha384, nil
}
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
case crypto.SHA512:
if isPSS {
return kms.SigningAlgorithmSpecRsassaPssSha512, nil
}
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
}
case *ecdsa.PublicKey:
switch h := opts.HashFunc(); h {
case crypto.SHA256:
return kms.SigningAlgorithmSpecEcdsaSha256, nil
case crypto.SHA384:
return kms.SigningAlgorithmSpecEcdsaSha384, nil
case crypto.SHA512:
return kms.SigningAlgorithmSpecEcdsaSha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
}
default:
return "", errors.Errorf("unsupported key type %T", key)
}
}

View file

@ -1,191 +0,0 @@
package awskms
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"fmt"
"io"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kms"
"go.step.sm/crypto/pemutil"
)
func TestNewSigner(t *testing.T) {
okClient := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type args struct {
svc KeyManagementClient
signingKey string
}
tests := []struct {
name string
args args
want *Signer
wantErr bool
}{
{"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{
service: okClient,
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
publicKey: key,
}, false},
{"fail parse", args{okClient, "awskms:key-id="}, nil, true},
{"fail preload", args{&MockClient{
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return nil, fmt.Errorf("an error")
},
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
{"fail preload not der", args{&MockClient{
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
return &kms.GetPublicKeyOutput{
KeyId: input.KeyId,
PublicKey: []byte(publicKey),
}, nil
},
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewSigner(tt.args.svc, tt.args.signingKey)
if (err != nil) != tt.wantErr {
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_Public(t *testing.T) {
okClient := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type fields struct {
service KeyManagementClient
keyID string
publicKey crypto.PublicKey
}
tests := []struct {
name string
fields fields
want crypto.PublicKey
}{
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
service: tt.fields.service,
keyID: tt.fields.keyID,
publicKey: tt.fields.publicKey,
}
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_Sign(t *testing.T) {
okClient := getOKClient()
key, err := pemutil.ParseKey([]byte(publicKey))
if err != nil {
t.Fatal(err)
}
type fields struct {
service KeyManagementClient
keyID string
publicKey crypto.PublicKey
}
type args struct {
rand io.Reader
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
}{
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
{"fail sign", fields{&MockClient{
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
return nil, fmt.Errorf("an error")
},
}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
service: tt.fields.service,
keyID: tt.fields.keyID,
publicKey: tt.fields.publicKey,
}
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getSigningAlgorithm(t *testing.T) {
type args struct {
key crypto.PublicKey
opts crypto.SignerOpts
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,81 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
gomock "github.com/golang/mock/gomock"
)
// KeyVaultClient is a mock of KeyVaultClient interface
type KeyVaultClient struct {
ctrl *gomock.Controller
recorder *KeyVaultClientMockRecorder
}
// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient
type KeyVaultClientMockRecorder struct {
mock *KeyVaultClient
}
// NewKeyVaultClient creates a new mock instance
func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient {
mock := &KeyVaultClient{ctrl: ctrl}
mock.recorder = &KeyVaultClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder {
return m.recorder
}
// CreateKey mocks base method
func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(keyvault.KeyBundle)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateKey indicates an expected call of CreateKey
func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3)
}
// GetKey mocks base method
func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(keyvault.KeyBundle)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKey indicates an expected call of GetKey
func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3)
}
// Sign mocks base method
func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(keyvault.KeyOperationResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Sign indicates an expected call of Sign
func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4)
}

View file

@ -1,342 +0,0 @@
package azurekms
import (
"context"
"crypto"
"regexp"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/Azure/go-autorest/autorest/date"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
)
func init() {
apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// Scheme is the scheme used for the Azure Key Vault uris.
const Scheme = "azurekms"
// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can
// extract the vault, name and version of the key.
var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`)
var (
valueTrue = true
value2048 int32 = 2048
value3072 int32 = 3072
value4096 int32 = 4096
)
var now = func() time.Time {
return time.Now().UTC()
}
type keyType struct {
Kty keyvault.JSONWebKeyType
Curve keyvault.JSONWebKeyCurveName
}
func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType {
switch k.Kty {
case keyvault.EC:
if pl == apiv1.HSM {
return keyvault.ECHSM
}
return k.Kty
case keyvault.RSA:
if pl == apiv1.HSM {
return keyvault.RSAHSM
}
return k.Kty
default:
return ""
}
}
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{
apiv1.UnspecifiedSignAlgorithm: {
Kty: keyvault.EC,
Curve: keyvault.P256,
},
apiv1.SHA256WithRSA: {
Kty: keyvault.RSA,
},
apiv1.SHA384WithRSA: {
Kty: keyvault.RSA,
},
apiv1.SHA512WithRSA: {
Kty: keyvault.RSA,
},
apiv1.SHA256WithRSAPSS: {
Kty: keyvault.RSA,
},
apiv1.SHA384WithRSAPSS: {
Kty: keyvault.RSA,
},
apiv1.SHA512WithRSAPSS: {
Kty: keyvault.RSA,
},
apiv1.ECDSAWithSHA256: {
Kty: keyvault.EC,
Curve: keyvault.P256,
},
apiv1.ECDSAWithSHA384: {
Kty: keyvault.EC,
Curve: keyvault.P384,
},
apiv1.ECDSAWithSHA512: {
Kty: keyvault.EC,
Curve: keyvault.P521,
},
}
// vaultResource is the value the client will use as audience.
const vaultResource = "https://vault.azure.net"
// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will
// be used for testing purposes.
type KeyVaultClient interface {
GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error)
CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error)
Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error)
}
// KeyVault implements a KMS using Azure Key Vault.
//
// The URI format used in Azure Key Vault is the following:
//
// - azurekms:name=key-name;vault=vault-name
// - azurekms:name=key-name;vault=vault-name?version=key-version
// - azurekms:name=key-name;vault=vault-name?hsm=true
//
// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault
// name where the key is located; "version" is an optional parameter that
// defines the version of they key, if version is not given, the latest one will
// be used; "hsm" defines if an HSM want to be used for this key, this is
// specially useful when this is used from `step`.
//
// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault
// package, at some point Azure might create a keyvault client with all the
// functionality in /sdk/keyvault, we should migrate to that once available.
type KeyVault struct {
baseClient KeyVaultClient
defaults DefaultOptions
}
// DefaultOptions are custom options that can be passed as defaults using the
// URI in apiv1.Options.
type DefaultOptions struct {
Vault string
ProtectionLevel apiv1.ProtectionLevel
}
var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
baseClient := keyvault.New()
// With an URI, try to log in only using client credentials in the URI.
// Client credentials requires:
// - client-id
// - client-secret
// - tenant-id
// And optionally the aad-endpoint to support custom clouds:
// - aad-endpoint (defaults to https://login.microsoftonline.com/)
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
// Required options
clientID := u.Get("client-id")
clientSecret := u.Get("client-secret")
tenantID := u.Get("tenant-id")
// optional
aadEndpoint := u.Get("aad-endpoint")
if clientID != "" && clientSecret != "" && tenantID != "" {
s := auth.EnvironmentSettings{
Values: map[string]string{
auth.ClientID: clientID,
auth.ClientSecret: clientSecret,
auth.TenantID: tenantID,
auth.Resource: vaultResource,
},
Environment: azure.PublicCloud,
}
if aadEndpoint != "" {
s.Environment.ActiveDirectoryEndpoint = aadEndpoint
}
baseClient.Authorizer, err = s.GetAuthorizer()
if err != nil {
return nil, err
}
return baseClient, nil
}
}
// Attempt to authorize with the following methods:
// 1. Environment variables.
// - Client credentials
// - Client certificate
// - Username and password
// - MSI
// 2. Using Azure CLI 2.0 on local development.
authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource)
if err != nil {
authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource)
if err != nil {
return nil, errors.Wrap(err, "error getting authorizer for key vault")
}
}
baseClient.Authorizer = authorizer
return &baseClient, nil
}
// New initializes a new KMS implemented using Azure Key Vault.
func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) {
baseClient, err := createClient(ctx, opts)
if err != nil {
return nil, err
}
// step and step-ca do not need and URI, but having a default vault and
// protection level is useful if this package is used as an api
var defaults DefaultOptions
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
defaults.Vault = u.Get("vault")
if u.GetBool("hsm") {
defaults.ProtectionLevel = apiv1.HSM
}
}
return &KeyVault{
baseClient: baseClient,
defaults: defaults,
}, nil
}
// GetPublicKey loads a public key from Azure Key Vault by its resource name.
func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
if req.Name == "" {
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
}
vault, name, version, _, err := parseKeyName(req.Name, k.defaults)
if err != nil {
return nil, err
}
ctx, cancel := defaultContext()
defer cancel()
resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version)
if err != nil {
return nil, errors.Wrap(err, "keyVault GetKey failed")
}
return convertKey(resp.Key)
}
// CreateKey creates a asymmetric key in Azure Key Vault.
func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
if req.Name == "" {
return nil, errors.New("createKeyRequest 'name' cannot be empty")
}
vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults)
if err != nil {
return nil, err
}
// Override protection level to HSM only if it's not specified, and is given
// in the uri.
protectionLevel := req.ProtectionLevel
if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm {
protectionLevel = apiv1.HSM
}
kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
if !ok {
return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm)
}
var keySize *int32
if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM {
switch req.Bits {
case 2048:
keySize = &value2048
case 0, 3072:
keySize = &value3072
case 4096:
keySize = &value4096
default:
return nil, errors.Errorf("keyVault does not support key size %d", req.Bits)
}
}
created := date.UnixTime(now())
ctx, cancel := defaultContext()
defer cancel()
resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{
Kty: kt.KeyType(protectionLevel),
KeySize: keySize,
Curve: kt.Curve,
KeyOps: &[]keyvault.JSONWebKeyOperation{
keyvault.Sign, keyvault.Verify,
},
KeyAttributes: &keyvault.KeyAttributes{
Enabled: &valueTrue,
Created: &created,
NotBefore: &created,
},
})
if err != nil {
return nil, errors.Wrap(err, "keyVault CreateKey failed")
}
publicKey, err := convertKey(resp.Key)
if err != nil {
return nil, err
}
keyURI := getKeyName(vault, name, resp)
return &apiv1.CreateKeyResponse{
Name: keyURI,
PublicKey: publicKey,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: keyURI,
},
}, nil
}
// CreateSigner returns a crypto.Signer from a previously created asymmetric key.
func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.SigningKey == "" {
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
}
return NewSigner(k.baseClient, req.SigningKey, k.defaults)
}
// Close closes the client connection to the Azure Key Vault. This is a noop.
func (k *KeyVault) Close() error {
return nil
}
// ValidateName validates that the given string is a valid URI.
func (k *KeyVault) ValidateName(s string) error {
_, _, _, _, err := parseKeyName(s, k.defaults)
return err
}

View file

@ -1,653 +0,0 @@
//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient
package azurekms
import (
"context"
"crypto"
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/date"
"github.com/golang/mock/gomock"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/azurekms/internal/mock"
"go.step.sm/crypto/keyutil"
"gopkg.in/square/go-jose.v2"
)
var errTest = fmt.Errorf("test error")
func mockNow(t *testing.T) time.Time {
old := now
t0 := time.Unix(1234567890, 123).UTC()
now = func() time.Time {
return t0
}
t.Cleanup(func() {
now = old
})
return t0
}
func mockClient(t *testing.T) *mock.KeyVaultClient {
t.Helper()
ctrl := gomock.NewController(t)
t.Cleanup(func() {
ctrl.Finish()
})
return mock.NewKeyVaultClient(ctrl)
}
func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey {
t.Helper()
b, err := json.Marshal(&jose.JSONWebKey{
Key: pub,
})
if err != nil {
t.Fatal(err)
}
key := new(keyvault.JSONWebKey)
if err := json.Unmarshal(b, key); err != nil {
t.Fatal(err)
}
return key
}
func Test_now(t *testing.T) {
t0 := now()
if loc := t0.Location(); loc != time.UTC {
t.Errorf("now() Location = %v, want %v", loc, time.UTC)
}
}
func TestNew(t *testing.T) {
client := mockClient(t)
old := createClient
t.Cleanup(func() {
createClient = old
})
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
setup func()
args args
want *KeyVault
wantErr bool
}{
{"ok", func() {
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
return client, nil
}
}, args{context.Background(), apiv1.Options{}}, &KeyVault{
baseClient: client,
}, false},
{"ok with vault", func() {
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
return client, nil
}
}, args{context.Background(), apiv1.Options{
URI: "azurekms:vault=my-vault",
}}, &KeyVault{
baseClient: client,
defaults: DefaultOptions{
Vault: "my-vault",
ProtectionLevel: apiv1.UnspecifiedProtectionLevel,
},
}, false},
{"ok with vault + hsm", func() {
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
return client, nil
}
}, args{context.Background(), apiv1.Options{
URI: "azurekms:vault=my-vault;hsm=true",
}}, &KeyVault{
baseClient: client,
defaults: DefaultOptions{
Vault: "my-vault",
ProtectionLevel: apiv1.HSM,
},
}, false},
{"fail", func() {
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
return nil, errTest
}
}, args{context.Background(), apiv1.Options{}}, nil, true},
{"fail uri", func() {
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
return client, nil
}
}, args{context.Background(), apiv1.Options{
URI: "kms:vault=my-vault;hsm=true",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestKeyVault_createClient(t *testing.T) {
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
skip bool
wantErr bool
}{
{"ok", args{context.Background(), apiv1.Options{}}, true, false},
{"ok with uri", args{context.Background(), apiv1.Options{
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id",
}}, false, false},
{"ok with uri+aad", args{context.Background(), apiv1.Options{
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F",
}}, false, false},
{"ok with uri no config", args{context.Background(), apiv1.Options{
URI: "azurekms:",
}}, true, false},
{"fail uri", args{context.Background(), apiv1.Options{
URI: "kms:client-id=id;client-secret=secret;tenant-id=id",
}}, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip {
t.SkipNow()
}
_, err := createClient(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestKeyVault_GetPublicKey(t *testing.T) {
key, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatal(err)
}
pub := key.Public()
jwk := createJWK(t, pub)
client := mockClient(t)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
type fields struct {
baseClient KeyVaultClient
}
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.PublicKey
wantErr bool
}{
{"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
}}, pub, false},
{"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key?version=my-version",
}}, pub, false},
{"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "azurekms:vault=my-vault;name=not-found?version=my-version",
}}, nil, true},
{"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "",
}}, nil, true},
{"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "azurekms:vault=;name=not-found?version=my-version",
}}, nil, true},
{"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{
Name: "azurekms:vault=;name=?version=my-version",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KeyVault{
baseClient: tt.fields.baseClient,
}
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestKeyVault_CreateKey(t *testing.T) {
ecKey, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatal(err)
}
rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048)
if err != nil {
t.Fatal(err)
}
ecPub := ecKey.Public()
rsaPub := rsaKey.Public()
ecJWK := createJWK(t, ecPub)
rsaJWK := createJWK(t, rsaPub)
t0 := date.UnixTime(mockNow(t))
client := mockClient(t)
expects := []struct {
Name string
Kty keyvault.JSONWebKeyType
KeySize *int32
Curve keyvault.JSONWebKeyCurveName
Key *keyvault.JSONWebKey
}{
{"P-256", keyvault.EC, nil, keyvault.P256, ecJWK},
{"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
{"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
{"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK},
{"P-384", keyvault.EC, nil, keyvault.P384, ecJWK},
{"P-521", keyvault.EC, nil, keyvault.P521, ecJWK},
{"RSA 0", keyvault.RSA, &value3072, "", rsaJWK},
{"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK},
{"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK},
{"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK},
{"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK},
{"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK},
}
for _, e := range expects {
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{
Kty: e.Kty,
KeySize: e.KeySize,
Curve: e.Curve,
KeyOps: &[]keyvault.JSONWebKeyOperation{
keyvault.Sign, keyvault.Verify,
},
KeyAttributes: &keyvault.KeyAttributes{
Enabled: &valueTrue,
Created: &t0,
NotBefore: &t0,
},
}).Return(keyvault.KeyBundle{
Key: e.Key,
}, nil)
}
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest)
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{
Key: nil,
}, nil)
type fields struct {
baseClient KeyVaultClient
}
type args struct {
req *apiv1.CreateKeyRequest
}
tests := []struct {
name string
fields fields
args args
want *apiv1.CreateKeyResponse
wantErr bool
}{
{"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
ProtectionLevel: apiv1.Software,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
ProtectionLevel: apiv1.HSM,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key?hsm=true",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: ecPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
Bits: 0,
SignatureAlgorithm: apiv1.SHA256WithRSA,
ProtectionLevel: apiv1.Software,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
Bits: 0,
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
ProtectionLevel: apiv1.HSM,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key;hsm=true",
Bits: 0,
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
Bits: 2048,
SignatureAlgorithm: apiv1.SHA384WithRSA,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
Bits: 3072,
SignatureAlgorithm: apiv1.SHA512WithRSA,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=my-key",
Bits: 4096,
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
}}, &apiv1.CreateKeyResponse{
Name: "azurekms:name=my-key;vault=my-vault",
PublicKey: rsaPub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: "azurekms:name=my-key;vault=my-vault",
},
}, false},
{"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=not-found",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
{"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=not-found",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
{"fail name", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "",
}}, nil, true},
{"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=;name=not-found?version=my-version",
}}, nil, true},
{"fail id", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=?version=my-version",
}}, nil, true},
{"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=not-found",
SignatureAlgorithm: apiv1.PureEd25519,
}}, nil, true},
{"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{
Name: "azurekms:vault=my-vault;name=not-found",
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
Bits: 1024,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KeyVault{
baseClient: tt.fields.baseClient,
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestKeyVault_CreateSigner(t *testing.T) {
key, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatal(err)
}
pub := key.Public()
jwk := createJWK(t, pub)
client := mockClient(t)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
type fields struct {
baseClient KeyVaultClient
}
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.Signer
wantErr bool
}{
{"ok", fields{client}, args{&apiv1.CreateSignerRequest{
SigningKey: "azurekms:vault=my-vault;name=my-key",
}}, &Signer{
client: client,
vaultBaseURL: "https://my-vault.vault.azure.net/",
name: "my-key",
version: "",
publicKey: pub,
}, false},
{"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{
SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version",
}}, &Signer{
client: client,
vaultBaseURL: "https://my-vault.vault.azure.net/",
name: "my-key",
version: "my-version",
publicKey: pub,
}, false},
{"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{
SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version",
}}, nil, true},
{"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{
SigningKey: "",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KeyVault{
baseClient: tt.fields.baseClient,
}
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want)
}
})
}
}
func TestKeyVault_Close(t *testing.T) {
client := mockClient(t)
type fields struct {
baseClient KeyVaultClient
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{client}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KeyVault{
baseClient: tt.fields.baseClient,
}
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_keyType_KeyType(t *testing.T) {
type fields struct {
Kty keyvault.JSONWebKeyType
Curve keyvault.JSONWebKeyCurveName
}
type args struct {
pl apiv1.ProtectionLevel
}
tests := []struct {
name string
fields fields
args args
want keyvault.JSONWebKeyType
}{
{"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC},
{"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC},
{"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM},
{"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA},
{"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA},
{"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM},
{"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := keyType{
Kty: tt.fields.Kty,
Curve: tt.fields.Curve,
}
if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want)
}
})
}
}
func TestKeyVault_ValidateName(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{"azurekms:name=my-key;vault=my-vault"}, false},
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false},
{"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true},
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true},
{"fail no name", args{"azurekms:vault=my-vault"}, true},
{"fail no vault", args{"azurekms:name=my-key"}, true},
{"fail empty", args{""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &KeyVault{}
if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr {
t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,182 +0,0 @@
package azurekms
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"encoding/base64"
"io"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// Signer implements a crypto.Signer using the AWS KMS.
type Signer struct {
client KeyVaultClient
vaultBaseURL string
name string
version string
publicKey crypto.PublicKey
}
// NewSigner creates a new signer using a key in the AWS KMS.
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
vault, name, version, _, err := parseKeyName(signingKey, defaults)
if err != nil {
return nil, err
}
// Make sure that the key exists.
signer := &Signer{
client: client,
vaultBaseURL: vaultBaseURL(vault),
name: name,
version: version,
}
if err := signer.preloadKey(); err != nil {
return nil, err
}
return signer, nil
}
func (s *Signer) preloadKey() error {
ctx, cancel := defaultContext()
defer cancel()
resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
if err != nil {
return errors.Wrap(err, "keyVault GetKey failed")
}
s.publicKey, err = convertKey(resp.Key)
return err
}
// Public returns the public key of this signer or an error.
func (s *Signer) Public() crypto.PublicKey {
return s.publicKey
}
// Sign signs digest with the private key stored in the AWS KMS.
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
alg, err := getSigningAlgorithm(s.Public(), opts)
if err != nil {
return nil, err
}
b64 := base64.RawURLEncoding.EncodeToString(digest)
// Sign with retry if the key is not ready
resp, err := s.signWithRetry(alg, b64, 3)
if err != nil {
return nil, errors.Wrap(err, "keyVault Sign failed")
}
sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
if err != nil {
return nil, errors.Wrap(err, "error decoding keyVault Sign result")
}
var octetSize int
switch alg {
case keyvault.ES256:
octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
case keyvault.ES384:
octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
case keyvault.ES512:
octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
default:
return sig, nil
}
// Convert to asn1
if len(sig) != octetSize*2 {
return nil, errors.Errorf("keyVault Sign failed: unexpected signature length")
}
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
})
return b.Bytes()
}
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) {
retry:
ctx, cancel := defaultContext()
defer cancel()
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
Algorithm: alg,
Value: &b64,
})
if err != nil && retryAttempts > 0 {
var requestError *azure.RequestError
if errors.As(err, &requestError) {
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
code, ok := se.InnerError["code"].(string)
if ok && code == "KeyNotYetValid" {
time.Sleep(time.Second / time.Duration(retryAttempts))
retryAttempts--
goto retry
}
}
}
}
return resp, err
}
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
switch key.(type) {
case *rsa.PublicKey:
hashFunc := opts.HashFunc()
pss, isPSS := opts.(*rsa.PSSOptions)
// Random salt lengths are not supported
if isPSS &&
pss.SaltLength != rsa.PSSSaltLengthAuto &&
pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
pss.SaltLength != hashFunc.Size() {
return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
}
switch h := hashFunc; h {
case crypto.SHA256:
if isPSS {
return keyvault.PS256, nil
}
return keyvault.RS256, nil
case crypto.SHA384:
if isPSS {
return keyvault.PS384, nil
}
return keyvault.RS384, nil
case crypto.SHA512:
if isPSS {
return keyvault.PS512, nil
}
return keyvault.RS512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
}
case *ecdsa.PublicKey:
switch h := opts.HashFunc(); h {
case crypto.SHA256:
return keyvault.ES256, nil
case crypto.SHA384:
return keyvault.ES384, nil
case crypto.SHA512:
return keyvault.ES512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
}
default:
return "", errors.Errorf("unsupported key type %T", key)
}
}

View file

@ -1,493 +0,0 @@
package azurekms
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"io"
"reflect"
"testing"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/golang/mock/gomock"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/keyutil"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
func TestNewSigner(t *testing.T) {
key, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatal(err)
}
pub := key.Public()
jwk := createJWK(t, pub)
client := mockClient(t)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
Key: jwk,
}, nil)
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
var noOptions DefaultOptions
type args struct {
client KeyVaultClient
signingKey string
defaults DefaultOptions
}
tests := []struct {
name string
args args
want crypto.Signer
wantErr bool
}{
{"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{
client: client,
vaultBaseURL: "https://my-vault.vault.azure.net/",
name: "my-key",
version: "",
publicKey: pub,
}, false},
{"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{
client: client,
vaultBaseURL: "https://my-vault.vault.azure.net/",
name: "my-key",
version: "my-version",
publicKey: pub,
}, false},
{"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{
client: client,
vaultBaseURL: "https://my-vault.vault.azure.net/",
name: "my-key",
version: "my-version",
publicKey: pub,
}, false},
{"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
{"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true},
{"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true},
{"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults)
if (err != nil) != tt.wantErr {
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_Public(t *testing.T) {
key, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatal(err)
}
pub := key.Public()
type fields struct {
publicKey crypto.PublicKey
}
tests := []struct {
name string
fields fields
want crypto.PublicKey
}{
{"ok", fields{pub}, pub},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
publicKey: tt.fields.publicKey,
}
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_Sign(t *testing.T) {
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
key, err := keyutil.GenerateSigner(kty, crv, bits)
if err != nil {
t.Fatal(err)
}
h := opts.HashFunc().New()
h.Write([]byte("random-data"))
sum := h.Sum(nil)
var sig, resultSig []byte
if priv, ok := key.(*ecdsa.PrivateKey); ok {
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
if err != nil {
t.Fatal(err)
}
curveBits := priv.Params().BitSize
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
// nolint:gocritic
resultSig = append(rBytesPadded, sBytesPadded...)
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1BigInt(r)
b.AddASN1BigInt(s)
})
sig, err = b.Bytes()
if err != nil {
t.Fatal(err)
}
} else {
sig, err = key.Sign(rand.Reader, sum, opts)
if err != nil {
t.Fatal(err)
}
resultSig = sig
}
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
}
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384)
p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512)
rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256)
rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384)
rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512)
rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA256,
})
rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA512,
})
rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA512,
})
ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0)
if err != nil {
t.Fatal(err)
}
client := mockClient(t)
expects := []struct {
name string
keyVersion string
alg keyvault.JSONWebKeySignatureAlgorithm
digest []byte
result keyvault.KeyOperationResult
err error
}{
{"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
Result: &p256ResultSig,
}, nil},
{"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{
Result: &p386ResultSig,
}, nil},
{"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{
Result: &p521ResultSig,
}, nil},
{"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{
Result: &rsaSHA256ResultSig,
}, nil},
{"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{
Result: &rsaSHA384ResultSig,
}, nil},
{"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{
Result: &rsaSHA512ResultSig,
}, nil},
{"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{
Result: &rsaPSSSHA256ResultSig,
}, nil},
{"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{
Result: &rsaPSSSHA384ResultSig,
}, nil},
{"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{
Result: &rsaPSSSHA512ResultSig,
}, nil},
// Errors
{"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest},
{"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
Result: &rsaSHA256ResultSig,
}, nil},
{"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
Result: func() *string {
v := "😎"
return &v
}(),
}, nil},
}
for _, e := range expects {
value := base64.RawURLEncoding.EncodeToString(e.digest)
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
Algorithm: e.alg,
Value: &value,
}).Return(e.result, e.err)
}
type fields struct {
client KeyVaultClient
vaultBaseURL string
name string
version string
publicKey crypto.PublicKey
}
type args struct {
rand io.Reader
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
}{
{"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, p256Sig, false},
{"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{
rand.Reader, p384Digest, crypto.SHA384,
}, p384Sig, false},
{"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{
rand.Reader, p521Digest, crypto.SHA512,
}, p521Sig, false},
{"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
rand.Reader, rsaSHA256Digest, crypto.SHA256,
}, rsaSHA256Sig, false},
{"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{
rand.Reader, rsaSHA384Digest, crypto.SHA384,
}, rsaSHA384Sig, false},
{"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{
rand.Reader, rsaSHA512Digest, crypto.SHA512,
}, rsaSHA512Sig, false},
{"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA256,
},
}, rsaPSSSHA256Sig, false},
{"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{
rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: crypto.SHA384,
},
}, rsaPSSSHA384Sig, false},
{"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{
rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{
SaltLength: 64,
Hash: crypto.SHA512,
},
}, rsaPSSSHA512Sig, false},
{"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
rand.Reader, rsaSHA256Digest, crypto.SHA256,
}, nil, true},
{"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, nil, true},
{"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, nil, true},
{"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
SaltLength: 64,
Hash: crypto.SHA256,
},
}, nil, true},
{"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
rand.Reader, rsaSHA256Digest, crypto.SHA1,
}, nil, true},
{"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.MD5,
}, nil, true},
{"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{
rand.Reader, []byte("message"), crypto.Hash(0),
}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
client: tt.fields.client,
vaultBaseURL: tt.fields.vaultBaseURL,
name: tt.fields.name,
version: tt.fields.version,
publicKey: tt.fields.publicKey,
}
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_Sign_signWithRetry(t *testing.T) {
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
key, err := keyutil.GenerateSigner(kty, crv, bits)
if err != nil {
t.Fatal(err)
}
h := opts.HashFunc().New()
h.Write([]byte("random-data"))
sum := h.Sum(nil)
var sig, resultSig []byte
if priv, ok := key.(*ecdsa.PrivateKey); ok {
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
if err != nil {
t.Fatal(err)
}
curveBits := priv.Params().BitSize
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
// nolint:gocritic
resultSig = append(rBytesPadded, sBytesPadded...)
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1BigInt(r)
b.AddASN1BigInt(s)
})
sig, err = b.Bytes()
if err != nil {
t.Fatal(err)
}
} else {
sig, err = key.Sign(rand.Reader, sum, opts)
if err != nil {
t.Fatal(err)
}
resultSig = sig
}
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
}
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
okResult := keyvault.KeyOperationResult{
Result: &p256ResultSig,
}
failResult := keyvault.KeyOperationResult{}
retryError := autorest.DetailedError{
Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
InnerError: map[string]interface{}{
"code": "KeyNotYetValid",
},
},
},
}
client := mockClient(t)
expects := []struct {
name string
keyVersion string
alg keyvault.JSONWebKeySignatureAlgorithm
digest []byte
result keyvault.KeyOperationResult
err error
}{
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
}
for _, e := range expects {
value := base64.RawURLEncoding.EncodeToString(e.digest)
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
Algorithm: e.alg,
Value: &value,
}).Return(e.result, e.err)
}
type fields struct {
client KeyVaultClient
vaultBaseURL string
name string
version string
publicKey crypto.PublicKey
}
type args struct {
rand io.Reader
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
}{
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, p256Sig, false},
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
client: tt.fields.client,
vaultBaseURL: tt.fields.vaultBaseURL,
name: tt.fields.name,
version: tt.fields.version,
publicKey: tt.fields.publicKey,
}
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,98 +0,0 @@
package azurekms
import (
"context"
"crypto"
"encoding/json"
"net/url"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
"go.step.sm/crypto/jose"
)
// defaultContext returns the default context used in requests to azure.
func defaultContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 15*time.Second)
}
// getKeyName returns the uri of the key vault key.
func getKeyName(vault, name string, bundle keyvault.KeyBundle) string {
if bundle.Key != nil && bundle.Key.Kid != nil {
sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1)
if len(sm) == 1 && len(sm[0]) == 4 {
m := sm[0]
u := uri.New(Scheme, url.Values{
"vault": []string{m[1]},
"name": []string{m[2]},
})
u.RawQuery = url.Values{"version": []string{m[3]}}.Encode()
return u.String()
}
}
// Fallback to URI without id.
return uri.New(Scheme, url.Values{
"vault": []string{vault},
"name": []string{name},
}).String()
}
// parseKeyName returns the key vault, name and version from URIs like:
//
// - azurekms:vault=key-vault;name=key-name
// - azurekms:vault=key-vault;name=key-name?version=key-id
// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true
//
// The key-id defines the version of the key, if it is not passed the latest
// version will be used.
//
// HSM can also be passed to define the protection level if this is not given in
// CreateQuery.
func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) {
var u *uri.URI
u, err = uri.ParseWithScheme(Scheme, rawURI)
if err != nil {
return
}
if name = u.Get("name"); name == "" {
err = errors.Errorf("key uri %s is not valid: name is missing", rawURI)
return
}
if vault = u.Get("vault"); vault == "" {
if defaults.Vault == "" {
name = ""
err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI)
return
}
vault = defaults.Vault
}
if u.Get("hsm") == "" {
hsm = (defaults.ProtectionLevel == apiv1.HSM)
} else {
hsm = u.GetBool("hsm")
}
version = u.Get("version")
return
}
func vaultBaseURL(vault string) string {
return "https://" + vault + ".vault.azure.net/"
}
func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) {
b, err := json.Marshal(key)
if err != nil {
return nil, errors.Wrap(err, "error marshaling key")
}
var jwk jose.JSONWebKey
if err := jwk.UnmarshalJSON(b); err != nil {
return nil, errors.Wrap(err, "error unmarshaling key")
}
return jwk.Key, nil
}

View file

@ -1,96 +0,0 @@
package azurekms
import (
"testing"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/smallstep/certificates/kms/apiv1"
)
func Test_getKeyName(t *testing.T) {
getBundle := func(kid string) keyvault.KeyBundle {
return keyvault.KeyBundle{
Key: &keyvault.JSONWebKey{
Kid: &kid,
},
}
}
type args struct {
vault string
name string
bundle keyvault.KeyBundle
}
tests := []struct {
name string
args args
want string
}{
{"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"},
{"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"},
{"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"},
{"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"},
{"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"},
{"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want {
t.Errorf("getKeyName() = %v, want %v", got, tt.want)
}
})
}
}
func Test_parseKeyName(t *testing.T) {
var noOptions DefaultOptions
type args struct {
rawURI string
defaults DefaultOptions
}
tests := []struct {
name string
args args
wantVault string
wantName string
wantVersion string
wantHsm bool
wantErr bool
}{
{"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
{"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
{"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false},
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false},
{"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false},
{"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false},
{"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false},
{"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true},
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true},
{"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true},
{"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true},
{"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true},
{"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true},
{"fail empty", args{"", noOptions}, "", "", "", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults)
if (err != nil) != tt.wantErr {
t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotVault != tt.wantVault {
t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault)
}
if gotName != tt.wantName {
t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName)
}
if gotVersion != tt.wantVersion {
t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion)
}
if gotHsm != tt.wantHsm {
t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm)
}
})
}
}

View file

@ -1,348 +0,0 @@
package cloudkms
import (
"context"
"crypto"
"crypto/x509"
"log"
"strings"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
cloudkms "cloud.google.com/go/kms/apiv1"
gax "github.com/googleapis/gax-go/v2"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
"go.step.sm/crypto/pemutil"
"google.golang.org/api/option"
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
)
// Scheme is the scheme used in uris.
const Scheme = "cloudkms"
const pendingGenerationRetries = 10
// protectionLevelMapping maps step protection levels with cloud kms ones.
var protectionLevelMapping = map[apiv1.ProtectionLevel]kmspb.ProtectionLevel{
apiv1.UnspecifiedProtectionLevel: kmspb.ProtectionLevel_PROTECTION_LEVEL_UNSPECIFIED,
apiv1.Software: kmspb.ProtectionLevel_SOFTWARE,
apiv1.HSM: kmspb.ProtectionLevel_HSM,
}
// signatureAlgorithmMapping is a mapping between the step signature algorithm,
// and bits for RSA keys, with cloud kms one.
//
// Cloud KMS does not support SHA384WithRSA, SHA384WithRSAPSS, SHA384WithRSAPSS,
// ECDSAWithSHA512, and PureEd25519.
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
apiv1.UnspecifiedSignAlgorithm: kmspb.CryptoKeyVersion_CRYPTO_KEY_VERSION_ALGORITHM_UNSPECIFIED,
apiv1.SHA256WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256,
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256,
},
apiv1.SHA512WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
},
apiv1.SHA256WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256,
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256,
},
apiv1.SHA512WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
},
apiv1.ECDSAWithSHA256: kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256,
apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384,
}
var cryptoKeyVersionMapping = map[kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm]x509.SignatureAlgorithm{
kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: x509.ECDSAWithSHA256,
kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384: x509.ECDSAWithSHA384,
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256: x509.SHA256WithRSA,
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256: x509.SHA256WithRSA,
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256: x509.SHA256WithRSA,
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: x509.SHA512WithRSA,
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256: x509.SHA256WithRSAPSS,
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256: x509.SHA256WithRSAPSS,
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256: x509.SHA256WithRSAPSS,
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512: x509.SHA512WithRSAPSS,
}
// KeyManagementClient defines the methods on KeyManagementClient that this
// package will use. This interface will be used for unit testing.
type KeyManagementClient interface {
Close() error
GetPublicKey(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
AsymmetricSign(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
CreateCryptoKey(context.Context, *kmspb.CreateCryptoKeyRequest, ...gax.CallOption) (*kmspb.CryptoKey, error)
GetKeyRing(context.Context, *kmspb.GetKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
CreateKeyRing(context.Context, *kmspb.CreateKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
CreateCryptoKeyVersion(ctx context.Context, req *kmspb.CreateCryptoKeyVersionRequest, opts ...gax.CallOption) (*kmspb.CryptoKeyVersion, error)
}
var newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
return cloudkms.NewKeyManagementClient(ctx, opts...)
}
// CloudKMS implements a KMS using Google's Cloud apiv1.
type CloudKMS struct {
client KeyManagementClient
}
// New creates a new CloudKMS configured with a new client.
func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) {
var cloudOpts []option.ClientOption
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
if f := u.Get("credentials-file"); f != "" {
cloudOpts = append(cloudOpts, option.WithCredentialsFile(f))
}
}
// Deprecated way to set configuration parameters.
if opts.CredentialsFile != "" {
cloudOpts = append(cloudOpts, option.WithCredentialsFile(opts.CredentialsFile))
}
client, err := newKeyManagementClient(ctx, cloudOpts...)
if err != nil {
return nil, err
}
return &CloudKMS{
client: client,
}, nil
}
func init() {
apiv1.Register(apiv1.CloudKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// NewCloudKMS creates a CloudKMS with a given client.
func NewCloudKMS(client KeyManagementClient) *CloudKMS {
return &CloudKMS{
client: client,
}
}
// Close closes the connection of the Cloud KMS client.
func (k *CloudKMS) Close() error {
if err := k.client.Close(); err != nil {
return errors.Wrap(err, "cloudKMS Close failed")
}
return nil
}
// CreateSigner returns a new cloudkms signer configured with the given signing
// key name.
func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.SigningKey == "" {
return nil, errors.New("signing key cannot be empty")
}
return NewSigner(k.client, req.SigningKey)
}
// CreateKey creates in Google's Cloud KMS a new asymmetric key for signing.
func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
if req.Name == "" {
return nil, errors.New("createKeyRequest 'name' cannot be empty")
}
protectionLevel, ok := protectionLevelMapping[req.ProtectionLevel]
if !ok {
return nil, errors.Errorf("cloudKMS does not support protection level '%s'", req.ProtectionLevel)
}
var signatureAlgorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
if !ok {
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
}
switch v := v.(type) {
case kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
signatureAlgorithm = v
case map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
if signatureAlgorithm, ok = v[req.Bits]; !ok {
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s' with '%d' bits", req.SignatureAlgorithm, req.Bits)
}
default:
return nil, errors.Errorf("unexpected error: this should not happen")
}
var crytoKeyName string
// Split `projects/PROJECT_ID/locations/global/keyRings/RING_ID/cryptoKeys/KEY_ID`
// to `projects/PROJECT_ID/locations/global/keyRings/RING_ID` and `KEY_ID`.
keyRing, keyID := Parent(req.Name)
if err := k.createKeyRingIfNeeded(keyRing); err != nil {
return nil, err
}
ctx, cancel := defaultContext()
defer cancel()
// Create private key in CloudKMS.
response, err := k.client.CreateCryptoKey(ctx, &kmspb.CreateCryptoKeyRequest{
Parent: keyRing,
CryptoKeyId: keyID,
CryptoKey: &kmspb.CryptoKey{
Purpose: kmspb.CryptoKey_ASYMMETRIC_SIGN,
VersionTemplate: &kmspb.CryptoKeyVersionTemplate{
ProtectionLevel: protectionLevel,
Algorithm: signatureAlgorithm,
},
},
})
if err != nil {
if status.Code(err) != codes.AlreadyExists {
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKey failed")
}
// Create a new version if the key already exists.
//
// Note that it will have the same purpose, protection level and
// algorithm than as previous one.
req := &kmspb.CreateCryptoKeyVersionRequest{
Parent: req.Name,
CryptoKeyVersion: &kmspb.CryptoKeyVersion{
State: kmspb.CryptoKeyVersion_ENABLED,
},
}
response, err := k.client.CreateCryptoKeyVersion(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed")
}
crytoKeyName = response.Name
} else {
crytoKeyName = response.Name + "/cryptoKeyVersions/1"
}
// Sleep deterministically to avoid retries because of PENDING_GENERATING.
// One second is often enough.
if protectionLevel == kmspb.ProtectionLevel_HSM {
time.Sleep(1 * time.Second)
}
// Retrieve public key to add it to the response.
pk, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Name: crytoKeyName,
})
if err != nil {
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
}
return &apiv1.CreateKeyResponse{
Name: crytoKeyName,
PublicKey: pk,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: crytoKeyName,
},
}, nil
}
func (k *CloudKMS) createKeyRingIfNeeded(name string) error {
ctx, cancel := defaultContext()
defer cancel()
_, err := k.client.GetKeyRing(ctx, &kmspb.GetKeyRingRequest{
Name: name,
})
if err == nil {
return nil
}
parent, child := Parent(name)
_, err = k.client.CreateKeyRing(ctx, &kmspb.CreateKeyRingRequest{
Parent: parent,
KeyRingId: child,
})
if err != nil && status.Code(err) != codes.AlreadyExists {
return errors.Wrap(err, "cloudKMS CreateKeyRing failed")
}
return nil
}
// GetPublicKey gets from Google's Cloud KMS a public key by name. Key names
// follow the pattern:
//
// projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/([a-zA-Z0-9_-]{1,63})/cryptoKeys/([a-zA-Z0-9_-]{1,63})/cryptoKeyVersions/([a-zA-Z0-9_-]{1,63})
func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
if req.Name == "" {
return nil, errors.New("createKeyRequest 'name' cannot be empty")
}
response, err := k.getPublicKeyWithRetries(req.Name, pendingGenerationRetries)
if err != nil {
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
}
pk, err := pemutil.ParseKey([]byte(response.Pem))
if err != nil {
return nil, err
}
return pk, nil
}
// getPublicKeyWithRetries retries the request if the error is
// FailedPrecondition, caused because the key is in the PENDING_GENERATION
// status.
func (k *CloudKMS) getPublicKeyWithRetries(name string, retries int) (response *kmspb.PublicKey, err error) {
workFn := func() (*kmspb.PublicKey, error) {
ctx, cancel := defaultContext()
defer cancel()
return k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: name,
})
}
for i := 0; i < retries; i++ {
if response, err = workFn(); err == nil {
return
}
if status.Code(err) == codes.FailedPrecondition {
log.Println("Waiting for key generation ...")
time.Sleep(time.Duration(i+1) * time.Second)
continue
}
}
return
}
func defaultContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 15*time.Second)
}
// Parent splits a string in the format `key/value/key2/value2` in a parent and
// child, for the previous string it will return `key/value` and `value2`.
func Parent(name string) (string, string) {
a, b := parent(name)
a, _ = parent(a)
return a, b
}
func parent(name string) (string, string) {
i := strings.LastIndex(name, "/")
switch i {
case -1:
return "", name
case 0:
return "", name[i+1:]
default:
return name[:i], name[i+1:]
}
}

View file

@ -1,464 +0,0 @@
package cloudkms
import (
"context"
"crypto"
"fmt"
"os"
"reflect"
"testing"
gax "github.com/googleapis/gax-go/v2"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/pemutil"
"google.golang.org/api/option"
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestParent(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
want1 string
}{
{"zero", args{"child"}, "", "child"},
{"one", args{"parent/child"}, "", "child"},
{"two", args{"grandparent/parent/child"}, "grandparent", "child"},
{"three", args{"great-grandparent/grandparent/parent/child"}, "great-grandparent/grandparent", "child"},
{"empty", args{""}, "", ""},
{"root", args{"/"}, "", ""},
{"child", args{"/child"}, "", "child"},
{"parent", args{"parent/"}, "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := Parent(tt.args.name)
if got != tt.want {
t.Errorf("Parent() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Parent() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestNew(t *testing.T) {
tmp := newKeyManagementClient
t.Cleanup(func() {
newKeyManagementClient = tmp
})
newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
if len(opts) > 0 {
return nil, fmt.Errorf("test error")
}
return &MockClient{}, nil
}
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
want *CloudKMS
wantErr bool
}{
{"ok", args{context.Background(), apiv1.Options{}}, &CloudKMS{client: &MockClient{}}, false},
{"ok with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:"}}, &CloudKMS{client: &MockClient{}}, false},
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
{"fail schema", args{context.Background(), apiv1.Options{URI: "pkcs11:"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestNew_real(t *testing.T) {
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
want *CloudKMS
wantErr bool
}{
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewCloudKMS(t *testing.T) {
type args struct {
client KeyManagementClient
}
tests := []struct {
name string
args args
want *CloudKMS
}{
{"ok", args{&MockClient{}}, &CloudKMS{&MockClient{}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewCloudKMS(tt.args.client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCloudKMS() = %v, want %v", got, tt.want)
}
})
}
}
func TestCloudKMS_Close(t *testing.T) {
type fields struct {
client KeyManagementClient
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{&MockClient{close: func() error { return nil }}}, false},
{"fail", fields{&MockClient{close: func() error { return fmt.Errorf("an error") }}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{
client: tt.fields.client,
}
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("CloudKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
type fields struct {
client KeyManagementClient
}
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.Signer
wantErr bool
}{
{"ok", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
{"fail", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, fmt.Errorf("test error")
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{
client: tt.fields.client,
}
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if signer, ok := got.(*Signer); ok {
signer.client = &MockClient{}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want)
}
})
}
}
func TestCloudKMS_CreateKey(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c"
testError := fmt.Errorf("an error")
alreadyExists := status.Error(codes.AlreadyExists, "already exists")
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
var retries int
type fields struct {
client KeyManagementClient
}
type args struct {
req *apiv1.CreateKeyRequest
}
tests := []struct {
name string
fields fields
args args
want *apiv1.CreateKeyResponse
wantErr bool
}{
{"ok", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
{"ok new key ring", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return nil, testError
},
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return nil, alreadyExists
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 3072}},
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
{"ok new key version", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return nil, alreadyExists
},
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
return &kmspb.CryptoKeyVersion{Name: keyName + "/cryptoKeyVersions/2"}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/2", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/2"}}, false},
{"ok with retries", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
if retries != 2 {
retries++
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
}
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
{"fail name", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{}}, nil, true},
{"fail protection level", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.ProtectionLevel(100)}}, nil, true},
{"fail signature algorithm", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, nil, true},
{"fail number of bits", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 1024}},
nil, true},
{"fail create key ring", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return nil, testError
},
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return nil, testError
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
nil, true},
{"fail create key", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return nil, testError
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
nil, true},
{"fail create key version", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return nil, alreadyExists
},
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
return nil, testError
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
nil, true},
{"fail get public key", fields{
&MockClient{
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
return &kmspb.KeyRing{}, nil
},
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, testError
},
}},
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{
client: tt.fields.client,
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("CloudKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CloudKMS.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestCloudKMS_GetPublicKey(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
testError := fmt.Errorf("an error")
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
var retries int
type fields struct {
client KeyManagementClient
}
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
fields fields
args args
want crypto.PublicKey
wantErr bool
}{
{"ok", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
{"ok with retries", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
if retries != 2 {
retries++
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
}
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
{"fail name", fields{&MockClient{}}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
{"fail get public key", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, testError
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
{"fail parse pem", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{
client: tt.fields.client,
}
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("CloudKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CloudKMS.GetPublicKey() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,46 +0,0 @@
package cloudkms
import (
"context"
gax "github.com/googleapis/gax-go/v2"
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
)
type MockClient struct {
close func() error
getPublicKey func(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
asymmetricSign func(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
createCryptoKey func(context.Context, *kmspb.CreateCryptoKeyRequest, ...gax.CallOption) (*kmspb.CryptoKey, error)
getKeyRing func(context.Context, *kmspb.GetKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
createKeyRing func(context.Context, *kmspb.CreateKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
createCryptoKeyVersion func(context.Context, *kmspb.CreateCryptoKeyVersionRequest, ...gax.CallOption) (*kmspb.CryptoKeyVersion, error)
}
func (m *MockClient) Close() error {
return m.close()
}
func (m *MockClient) GetPublicKey(ctx context.Context, req *kmspb.GetPublicKeyRequest, opts ...gax.CallOption) (*kmspb.PublicKey, error) {
return m.getPublicKey(ctx, req, opts...)
}
func (m *MockClient) AsymmetricSign(ctx context.Context, req *kmspb.AsymmetricSignRequest, opts ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
return m.asymmetricSign(ctx, req, opts...)
}
func (m *MockClient) CreateCryptoKey(ctx context.Context, req *kmspb.CreateCryptoKeyRequest, opts ...gax.CallOption) (*kmspb.CryptoKey, error) {
return m.createCryptoKey(ctx, req, opts...)
}
func (m *MockClient) GetKeyRing(ctx context.Context, req *kmspb.GetKeyRingRequest, opts ...gax.CallOption) (*kmspb.KeyRing, error) {
return m.getKeyRing(ctx, req, opts...)
}
func (m *MockClient) CreateKeyRing(ctx context.Context, req *kmspb.CreateKeyRingRequest, opts ...gax.CallOption) (*kmspb.KeyRing, error) {
return m.createKeyRing(ctx, req, opts...)
}
func (m *MockClient) CreateCryptoKeyVersion(ctx context.Context, req *kmspb.CreateCryptoKeyVersionRequest, opts ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
return m.createCryptoKeyVersion(ctx, req, opts...)
}

View file

@ -1,95 +0,0 @@
package cloudkms
import (
"crypto"
"crypto/x509"
"io"
"github.com/pkg/errors"
"go.step.sm/crypto/pemutil"
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
)
// Signer implements a crypto.Signer using Google's Cloud KMS.
type Signer struct {
client KeyManagementClient
signingKey string
algorithm x509.SignatureAlgorithm
publicKey crypto.PublicKey
}
// NewSigner creates a new crypto.Signer the given CloudKMS signing key.
func NewSigner(c KeyManagementClient, signingKey string) (*Signer, error) {
// Make sure that the key exists.
signer := &Signer{
client: c,
signingKey: signingKey,
}
if err := signer.preloadKey(signingKey); err != nil {
return nil, err
}
return signer, nil
}
func (s *Signer) preloadKey(signingKey string) error {
ctx, cancel := defaultContext()
defer cancel()
response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: signingKey,
})
if err != nil {
return errors.Wrap(err, "cloudKMS GetPublicKey failed")
}
s.algorithm = cryptoKeyVersionMapping[response.Algorithm]
s.publicKey, err = pemutil.ParseKey([]byte(response.Pem))
return err
}
// Public returns the public key of this signer or an error.
func (s *Signer) Public() crypto.PublicKey {
return s.publicKey
}
// Sign signs digest with the private key stored in Google's Cloud KMS.
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
req := &kmspb.AsymmetricSignRequest{
Name: s.signingKey,
Digest: &kmspb.Digest{},
}
switch h := opts.HashFunc(); h {
case crypto.SHA256:
req.Digest.Digest = &kmspb.Digest_Sha256{
Sha256: digest,
}
case crypto.SHA384:
req.Digest.Digest = &kmspb.Digest_Sha384{
Sha384: digest,
}
case crypto.SHA512:
req.Digest.Digest = &kmspb.Digest_Sha512{
Sha512: digest,
}
default:
return nil, errors.Errorf("unsupported hash function %v", h)
}
ctx, cancel := defaultContext()
defer cancel()
response, err := s.client.AsymmetricSign(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "cloudKMS AsymmetricSign failed")
}
return response.Signature, nil
}
// SignatureAlgorithm returns the algorithm that must be specified in a
// certificate to sign. This is specially important to distinguish RSA and
// RSAPSS schemas.
func (s *Signer) SignatureAlgorithm() x509.SignatureAlgorithm {
return s.algorithm
}

View file

@ -1,235 +0,0 @@
package cloudkms
import (
"context"
"crypto"
"crypto/rand"
"crypto/x509"
"fmt"
"io"
"os"
"reflect"
"testing"
gax "github.com/googleapis/gax-go/v2"
"go.step.sm/crypto/pemutil"
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
)
func Test_newSigner(t *testing.T) {
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
type args struct {
c KeyManagementClient
signingKey string
}
tests := []struct {
name string
args args
want *Signer
wantErr bool
}{
{"ok", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey", publicKey: pk}, false},
{"fail get public key", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, fmt.Errorf("an error")
},
}, "signingKey"}, nil, true},
{"fail parse pem", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}, "signingKey"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewSigner(tt.args.c, tt.args.signingKey)
if (err != nil) != tt.wantErr {
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
got.client = &MockClient{}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
}
})
}
}
func Test_signer_Public(t *testing.T) {
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
type fields struct {
client KeyManagementClient
signingKey string
publicKey crypto.PublicKey
}
tests := []struct {
name string
fields fields
want crypto.PublicKey
}{
{"ok", fields{&MockClient{}, "signingKey", pk}, pk},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
client: tt.fields.client,
signingKey: tt.fields.signingKey,
publicKey: tt.fields.publicKey,
}
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("signer.Public() = %v, want %v", got, tt.want)
}
})
}
}
func Test_signer_Sign(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
okClient := &MockClient{
asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
return &kmspb.AsymmetricSignResponse{Signature: []byte("ok signature")}, nil
},
}
failClient := &MockClient{
asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
return nil, fmt.Errorf("an error")
},
}
type fields struct {
client KeyManagementClient
signingKey string
}
type args struct {
rand io.Reader
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
}{
{"ok sha256", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, []byte("ok signature"), false},
{"ok sha384", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA384}, []byte("ok signature"), false},
{"ok sha512", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA512}, []byte("ok signature"), false},
{"fail MD5", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"fail asymmetric sign", fields{failClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
client: tt.fields.client,
signingKey: tt.fields.signingKey,
}
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("signer.Sign() = %v, want %v", got, tt.want)
}
})
}
}
func TestSigner_SignatureAlgorithm(t *testing.T) {
pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
client := &MockClient{
getPublicKey: func(_ context.Context, req *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
var algorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
switch req.Name {
case "ECDSA-SHA256":
algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256
case "ECDSA-SHA384":
algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384
case "SHA256-RSA-2048":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256
case "SHA256-RSA-3072":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256
case "SHA256-RSA-4096":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256
case "SHA512-RSA-4096":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512
case "SHA256-RSAPSS-2048":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256
case "SHA256-RSAPSS-3072":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256
case "SHA256-RSAPSS-4096":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256
case "SHA512-RSAPSS-4096":
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512
}
return &kmspb.PublicKey{
Pem: string(pemBytes),
Algorithm: algorithm,
}, nil
},
}
if err != nil {
t.Fatal(err)
}
type fields struct {
client KeyManagementClient
signingKey string
}
tests := []struct {
name string
fields fields
want x509.SignatureAlgorithm
}{
{"ECDSA-SHA256", fields{client, "ECDSA-SHA256"}, x509.ECDSAWithSHA256},
{"ECDSA-SHA384", fields{client, "ECDSA-SHA384"}, x509.ECDSAWithSHA384},
{"SHA256-RSA-2048", fields{client, "SHA256-RSA-2048"}, x509.SHA256WithRSA},
{"SHA256-RSA-3072", fields{client, "SHA256-RSA-3072"}, x509.SHA256WithRSA},
{"SHA256-RSA-4096", fields{client, "SHA256-RSA-4096"}, x509.SHA256WithRSA},
{"SHA512-RSA-4096", fields{client, "SHA512-RSA-4096"}, x509.SHA512WithRSA},
{"SHA256-RSAPSS-2048", fields{client, "SHA256-RSAPSS-2048"}, x509.SHA256WithRSAPSS},
{"SHA256-RSAPSS-3072", fields{client, "SHA256-RSAPSS-3072"}, x509.SHA256WithRSAPSS},
{"SHA256-RSAPSS-4096", fields{client, "SHA256-RSAPSS-4096"}, x509.SHA256WithRSAPSS},
{"SHA512-RSAPSS-4096", fields{client, "SHA512-RSAPSS-4096"}, x509.SHA512WithRSAPSS},
{"unknown", fields{client, "UNKNOWN"}, x509.UnknownSignatureAlgorithm},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
signer, err := NewSigner(tt.fields.client, tt.fields.signingKey)
if err != nil {
t.Errorf("NewSigner() error = %v", err)
}
if got := signer.SignatureAlgorithm(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.SignatureAlgorithm() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,4 +0,0 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
-----END PUBLIC KEY-----

View file

@ -1,43 +0,0 @@
package kms
import (
"context"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
// Enable default implementation
"github.com/smallstep/certificates/kms/softkms"
)
// KeyManager is the interface implemented by all the KMS.
type KeyManager = apiv1.KeyManager
// CertificateManager is the interface implemented by the KMS that can load and
// store x509.Certificates.
type CertificateManager = apiv1.CertificateManager
// Options are the KMS options. They represent the kms object in the ca.json.
type Options = apiv1.Options
// Default is the implementation of the default KMS.
var Default = &softkms.SoftKMS{}
// New initializes a new KMS from the given type.
func New(ctx context.Context, opts apiv1.Options) (KeyManager, error) {
if err := opts.Validate(); err != nil {
return nil, err
}
t := apiv1.Type(strings.ToLower(opts.Type))
if t == apiv1.DefaultKMS {
t = apiv1.SoftKMS
}
fn, ok := apiv1.LoadKeyManagerNewFunc(t)
if !ok {
return nil, errors.Errorf("unsupported kms type '%s'", t)
}
return fn(ctx, opts)
}

View file

@ -1,52 +0,0 @@
package kms
import (
"context"
"os"
"reflect"
"testing"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/awskms"
"github.com/smallstep/certificates/kms/cloudkms"
"github.com/smallstep/certificates/kms/softkms"
)
func TestNew(t *testing.T) {
ctx := context.Background()
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
skipOnCI bool
args args
want KeyManager
wantErr bool
}{
{"softkms", false, args{ctx, apiv1.Options{Type: "softkms"}}, &softkms.SoftKMS{}, false},
{"default", false, args{ctx, apiv1.Options{}}, &softkms.SoftKMS{}, false},
{"awskms", false, args{ctx, apiv1.Options{Type: "awskms"}}, &awskms.KMS{}, false},
{"cloudkms", true, args{ctx, apiv1.Options{Type: "cloudkms"}}, &cloudkms.CloudKMS{}, true}, // fails because not credentials
{"pkcs11", false, args{ctx, apiv1.Options{Type: "pkcs11"}}, nil, true}, // not yet supported
{"fail validation", false, args{ctx, apiv1.Options{Type: "foobar"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skipOnCI && os.Getenv("CI") == "true" {
t.SkipNow()
}
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if reflect.TypeOf(got) != reflect.TypeOf(tt.want) {
t.Errorf("New() = %T, want %T", got, tt.want)
}
})
}
}

View file

@ -1,83 +0,0 @@
//go:build cgo
// +build cgo
package pkcs11
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/smallstep/certificates/kms/apiv1"
)
func benchmarkSign(b *testing.B, signer crypto.Signer, opts crypto.SignerOpts) {
hash := opts.HashFunc()
h := hash.New()
h.Write([]byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger"))
digest := h.Sum(nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
signer.Sign(rand.Reader, digest, opts)
}
b.StopTimer()
}
func BenchmarkSignRSA(b *testing.B) {
k := setupPKCS11(b)
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7371;object=rsa-key",
})
if err != nil {
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
}
benchmarkSign(b, signer, crypto.SHA256)
}
func BenchmarkSignRSAPSS(b *testing.B) {
k := setupPKCS11(b)
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7372;object=rsa-pss-key",
})
if err != nil {
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
}
benchmarkSign(b, signer, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: crypto.SHA256,
})
}
func BenchmarkSignP256(b *testing.B) {
k := setupPKCS11(b)
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key",
})
if err != nil {
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
}
benchmarkSign(b, signer, crypto.SHA256)
}
func BenchmarkSignP384(b *testing.B) {
k := setupPKCS11(b)
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7374;object=ecdsa-p384-key",
})
if err != nil {
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
}
benchmarkSign(b, signer, crypto.SHA384)
}
func BenchmarkSignP521(b *testing.B) {
k := setupPKCS11(b)
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7375;object=ecdsa-p521-key",
})
if err != nil {
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
}
benchmarkSign(b, signer, crypto.SHA512)
}

View file

@ -1,64 +0,0 @@
//go:build opensc
// +build opensc
package pkcs11
import (
"runtime"
"sync"
"github.com/ThalesIgnite/crypto11"
)
var softHSM2Once sync.Once
// mustPKCS11 configures a *PKCS11 KMS to be used with OpenSC, using for example
// a Nitrokey HSM. To initialize these tests we should run:
//
// sc-hsm-tool --initialize --so-pin 3537363231383830 --pin 123456
//
// Or:
//
// pkcs11-tool --module /usr/local/lib/opensc-pkcs11.so \
// --init-token --init-pin \
// --so-pin=3537363231383830 --new-pin=123456 --pin=123456 \
// --label="pkcs11-test"
func mustPKCS11(t TBTesting) *PKCS11 {
t.Helper()
testModule = "OpenSC"
if runtime.GOARCH != "amd64" {
t.Fatalf("opensc test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
}
var path string
switch runtime.GOOS {
case "darwin":
path = "/usr/local/lib/opensc-pkcs11.so"
case "linux":
path = "/usr/local/lib/opensc-pkcs11.so"
default:
t.Skipf("opensc test skipped on %s", runtime.GOOS)
return nil
}
var zero int
p11, err := crypto11.Configure(&crypto11.Config{
Path: path,
SlotNumber: &zero,
Pin: "123456",
})
if err != nil {
t.Fatalf("failed to configure opensc on %s: %v", runtime.GOOS, err)
}
k := &PKCS11{
p11: p11,
}
// Setup
softHSM2Once.Do(func() {
teardown(t, k)
setup(t, k)
})
return k
}

View file

@ -1,210 +0,0 @@
//go:build cgo && !softhsm2 && !yubihsm2 && !opensc
// +build cgo,!softhsm2,!yubihsm2,!opensc
package pkcs11
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"io"
"math/big"
"github.com/ThalesIgnite/crypto11"
"github.com/pkg/errors"
)
func mustPKCS11(t TBTesting) *PKCS11 {
t.Helper()
testModule = "Golang crypto"
k := &PKCS11{
p11: &stubPKCS11{
signerIndex: make(map[keyType]int),
certIndex: make(map[keyType]int),
},
}
for i := range testCerts {
testCerts[i].Certificates = nil
}
teardown(t, k)
setup(t, k)
return k
}
type keyType struct {
id string
label string
serial string
}
func newKey(id, label []byte, serial *big.Int) keyType {
var serialString string
if serial != nil {
serialString = serial.String()
}
return keyType{
id: string(id),
label: string(label),
serial: serialString,
}
}
type stubPKCS11 struct {
signers []crypto11.Signer
certs []*x509.Certificate
signerIndex map[keyType]int
certIndex map[keyType]int
}
func (s *stubPKCS11) FindKeyPair(id, label []byte) (crypto11.Signer, error) {
if id == nil && label == nil {
return nil, errors.New("id and label cannot both be nil")
}
if i, ok := s.signerIndex[newKey(id, label, nil)]; ok {
return s.signers[i], nil
}
return nil, nil
}
func (s *stubPKCS11) FindCertificate(id, label []byte, serial *big.Int) (*x509.Certificate, error) {
if id == nil && label == nil && serial == nil {
return nil, errors.New("id, label and serial cannot both be nil")
}
if i, ok := s.certIndex[newKey(id, label, serial)]; ok {
return s.certs[i], nil
}
return nil, nil
}
func (s *stubPKCS11) ImportCertificateWithAttributes(template crypto11.AttributeSet, cert *x509.Certificate) error {
var id, label []byte
if v := template[crypto11.CkaId]; v != nil {
id = v.Value
}
if v := template[crypto11.CkaLabel]; v != nil {
label = v.Value
}
return s.ImportCertificateWithLabel(id, label, cert)
}
func (s *stubPKCS11) ImportCertificateWithLabel(id, label []byte, cert *x509.Certificate) error {
switch {
case id == nil:
return errors.New("id cannot both be nil")
case label == nil:
return errors.New("label cannot both be nil")
case cert == nil:
return errors.New("certificate cannot be nil")
}
i := len(s.certs)
s.certs = append(s.certs, cert)
s.certIndex[newKey(id, label, cert.SerialNumber)] = i
s.certIndex[newKey(id, nil, nil)] = i
s.certIndex[newKey(nil, label, nil)] = i
s.certIndex[newKey(nil, nil, cert.SerialNumber)] = i
s.certIndex[newKey(id, label, nil)] = i
s.certIndex[newKey(id, nil, cert.SerialNumber)] = i
s.certIndex[newKey(nil, label, cert.SerialNumber)] = i
return nil
}
func (s *stubPKCS11) DeleteCertificate(id, label []byte, serial *big.Int) error {
if id == nil && label == nil && serial == nil {
return errors.New("id, label and serial cannot both be nil")
}
if i, ok := s.certIndex[newKey(id, label, serial)]; ok {
s.certs[i] = nil
}
return nil
}
func (s *stubPKCS11) GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error) {
var id, label []byte
if v := public[crypto11.CkaId]; v != nil {
id = v.Value
}
if v := public[crypto11.CkaLabel]; v != nil {
label = v.Value
}
return s.GenerateRSAKeyPairWithLabel(id, label, bits)
}
func (s *stubPKCS11) GenerateRSAKeyPairWithLabel(id, label []byte, bits int) (crypto11.SignerDecrypter, error) {
if id == nil && label == nil {
return nil, errors.New("id and label cannot both be nil")
}
p, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
k := &privateKey{
Signer: p,
index: len(s.signers),
stub: s,
}
s.signers = append(s.signers, k)
s.signerIndex[newKey(id, label, nil)] = k.index
s.signerIndex[newKey(id, nil, nil)] = k.index
s.signerIndex[newKey(nil, label, nil)] = k.index
return k, nil
}
func (s *stubPKCS11) GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error) {
var id, label []byte
if v := public[crypto11.CkaId]; v != nil {
id = v.Value
}
if v := public[crypto11.CkaLabel]; v != nil {
label = v.Value
}
return s.GenerateECDSAKeyPairWithLabel(id, label, curve)
}
func (s *stubPKCS11) GenerateECDSAKeyPairWithLabel(id, label []byte, curve elliptic.Curve) (crypto11.Signer, error) {
if id == nil && label == nil {
return nil, errors.New("id and label cannot both be nil")
}
p, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, err
}
k := &privateKey{
Signer: p,
index: len(s.signers),
stub: s,
}
s.signers = append(s.signers, k)
s.signerIndex[newKey(id, label, nil)] = k.index
s.signerIndex[newKey(id, nil, nil)] = k.index
s.signerIndex[newKey(nil, label, nil)] = k.index
return k, nil
}
func (s *stubPKCS11) Close() error {
return nil
}
type privateKey struct {
crypto.Signer
index int
stub *stubPKCS11
}
func (s *privateKey) Delete() error {
s.stub.signers[s.index] = nil
return nil
}
func (s *privateKey) Decrypt(rnd io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
k, ok := s.Signer.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("key is not an rsa key")
}
return k.Decrypt(rnd, msg, opts)
}

View file

@ -1,399 +0,0 @@
//go:build cgo
// +build cgo
package pkcs11
import (
"context"
"crypto"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"fmt"
"math/big"
"strconv"
"sync"
"github.com/ThalesIgnite/crypto11"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
)
// Scheme is the scheme used in uris.
const Scheme = "pkcs11"
// DefaultRSASize is the number of bits of a new RSA key if no size has been
// specified.
const DefaultRSASize = 3072
// P11 defines the methods on crypto11.Context that this package will use. This
// interface will be used for unit testing.
type P11 interface {
FindKeyPair(id, label []byte) (crypto11.Signer, error)
FindCertificate(id, label []byte, serial *big.Int) (*x509.Certificate, error)
ImportCertificateWithAttributes(template crypto11.AttributeSet, certificate *x509.Certificate) error
DeleteCertificate(id, label []byte, serial *big.Int) error
GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error)
GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error)
Close() error
}
var p11Configure = func(config *crypto11.Config) (P11, error) {
return crypto11.Configure(config)
}
// PKCS11 is the implementation of a KMS using the PKCS #11 standard.
type PKCS11 struct {
p11 P11
closed sync.Once
}
// New returns a new PKCS11 KMS.
func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) {
var config crypto11.Config
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
config.Pin = u.Pin()
config.Path = u.Get("module-path")
config.TokenLabel = u.Get("token")
config.TokenSerial = u.Get("serial")
if v := u.Get("slot-id"); v != "" {
n, err := strconv.Atoi(v)
if err != nil {
return nil, errors.Wrap(err, "kms uri 'slot-id' is not valid")
}
config.SlotNumber = &n
}
}
if config.Pin == "" && opts.Pin != "" {
config.Pin = opts.Pin
}
switch {
case config.Path == "":
return nil, errors.New("kms uri 'module-path' are required")
case config.TokenLabel == "" && config.TokenSerial == "" && config.SlotNumber == nil:
return nil, errors.New("kms uri 'token', 'serial' or 'slot-id' are required")
case config.Pin == "":
return nil, errors.New("kms 'pin' cannot be empty")
case config.TokenLabel != "" && config.TokenSerial != "":
return nil, errors.New("kms uri 'token' and 'serial' are mutually exclusive")
case config.TokenLabel != "" && config.SlotNumber != nil:
return nil, errors.New("kms uri 'token' and 'slot-id' are mutually exclusive")
case config.TokenSerial != "" && config.SlotNumber != nil:
return nil, errors.New("kms uri 'serial' and 'slot-id' are mutually exclusive")
}
p11, err := p11Configure(&config)
if err != nil {
return nil, errors.Wrap(err, "error initializing PKCS#11")
}
return &PKCS11{
p11: p11,
}, nil
}
func init() {
apiv1.Register(apiv1.PKCS11, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// GetPublicKey returns the public key ....
func (k *PKCS11) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
if req.Name == "" {
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
}
signer, err := findSigner(k.p11, req.Name)
if err != nil {
return nil, errors.Wrap(err, "getPublicKey failed")
}
return signer.Public(), nil
}
// CreateKey generates a new key in the PKCS#11 module and returns the public key.
func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
switch {
case req.Name == "":
return nil, errors.New("createKeyRequest 'name' cannot be empty")
case req.Bits < 0:
return nil, errors.New("createKeyRequest 'bits' cannot be negative")
}
signer, err := generateKey(k.p11, req)
if err != nil {
return nil, errors.Wrap(err, "createKey failed")
}
return &apiv1.CreateKeyResponse{
Name: req.Name,
PublicKey: signer.Public(),
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: req.Name,
},
}, nil
}
// CreateSigner creates a signer using a key present in the PKCS#11 module.
func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.SigningKey == "" {
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
}
signer, err := findSigner(k.p11, req.SigningKey)
if err != nil {
return nil, errors.Wrap(err, "createSigner failed")
}
return signer, nil
}
// CreateDecrypter creates a decrypter using a key present in the PKCS#11
// module.
func (k *PKCS11) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
if req.DecryptionKey == "" {
return nil, errors.New("createDecrypterRequest 'decryptionKey' cannot be empty")
}
signer, err := findSigner(k.p11, req.DecryptionKey)
if err != nil {
return nil, errors.Wrap(err, "createDecrypterRequest failed")
}
// Only RSA keys will implement the Decrypter interface.
if _, ok := signer.Public().(*rsa.PublicKey); ok {
if dec, ok := signer.(crypto.Decrypter); ok {
return dec, nil
}
}
return nil, errors.New("createDecrypterRequest failed: signer does not implement crypto.Decrypter")
}
// LoadCertificate implements kms.CertificateManager and loads a certificate
// from the YubiKey.
func (k *PKCS11) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {
if req.Name == "" {
return nil, errors.New("loadCertificateRequest 'name' cannot be nil")
}
cert, err := findCertificate(k.p11, req.Name)
if err != nil {
return nil, errors.Wrap(err, "loadCertificate failed")
}
return cert, nil
}
// StoreCertificate implements kms.CertificateManager and stores a certificate
// in the YubiKey.
func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
switch {
case req.Name == "":
return errors.New("storeCertificateRequest 'name' cannot be empty")
case req.Certificate == nil:
return errors.New("storeCertificateRequest 'Certificate' cannot be nil")
}
id, object, err := parseObject(req.Name)
if err != nil {
return errors.Wrap(err, "storeCertificate failed")
}
// Enforce the use of both id and labels. This is not strictly necessary in
// PKCS #11, but it's a good practice.
if len(id) == 0 || len(object) == 0 {
return errors.Errorf("key with uri %s is not valid, id and object are required", req.Name)
}
cert, err := k.p11.FindCertificate(id, object, nil)
if err != nil {
return errors.Wrap(err, "storeCertificate failed")
}
if cert != nil {
return errors.Wrap(apiv1.ErrAlreadyExists{
Message: req.Name + " already exists",
}, "storeCertificate failed")
}
// Import certificate with the necessary attributes.
template, err := crypto11.NewAttributeSetWithIDAndLabel(id, object)
if err != nil {
return errors.Wrap(err, "storeCertificate failed")
}
if req.Extractable {
template.Set(crypto11.CkaExtractable, true)
}
if err := k.p11.ImportCertificateWithAttributes(template, req.Certificate); err != nil {
return errors.Wrap(err, "storeCertificate failed")
}
return nil
}
// DeleteKey is a utility function to delete a key given an uri.
func (k *PKCS11) DeleteKey(u string) error {
id, object, err := parseObject(u)
if err != nil {
return errors.Wrap(err, "deleteKey failed")
}
signer, err := k.p11.FindKeyPair(id, object)
if err != nil {
return errors.Wrap(err, "deleteKey failed")
}
if signer == nil {
return nil
}
if err := signer.Delete(); err != nil {
return errors.Wrap(err, "deleteKey failed")
}
return nil
}
// DeleteCertificate is a utility function to delete a certificate given an uri.
func (k *PKCS11) DeleteCertificate(u string) error {
id, object, err := parseObject(u)
if err != nil {
return errors.Wrap(err, "deleteCertificate failed")
}
if err := k.p11.DeleteCertificate(id, object, nil); err != nil {
return errors.Wrap(err, "deleteCertificate failed")
}
return nil
}
// Close releases the connection to the PKCS#11 module.
func (k *PKCS11) Close() (err error) {
k.closed.Do(func() {
err = errors.Wrap(k.p11.Close(), "error closing pkcs#11 context")
})
return
}
func toByte(s string) []byte {
if s == "" {
return nil
}
return []byte(s)
}
func parseObject(rawuri string) ([]byte, []byte, error) {
u, err := uri.ParseWithScheme(Scheme, rawuri)
if err != nil {
return nil, nil, err
}
id := u.GetEncoded("id")
object := u.Get("object")
if len(id) == 0 && object == "" {
return nil, nil, errors.Errorf("key with uri %s is not valid, id or object are required", rawuri)
}
return id, toByte(object), nil
}
func generateKey(ctx P11, req *apiv1.CreateKeyRequest) (crypto11.Signer, error) {
id, object, err := parseObject(req.Name)
if err != nil {
return nil, err
}
signer, err := ctx.FindKeyPair(id, object)
if err != nil {
return nil, err
}
if signer != nil {
return nil, apiv1.ErrAlreadyExists{
Message: req.Name + " already exists",
}
}
// Enforce the use of both id and labels. This is not strictly necessary in
// PKCS #11, but it's a good practice.
if len(id) == 0 || len(object) == 0 {
return nil, errors.Errorf("key with uri %s is not valid, id and object are required", req.Name)
}
// Create template for public and private keys
public, err := crypto11.NewAttributeSetWithIDAndLabel(id, object)
if err != nil {
return nil, err
}
private := public.Copy()
if req.Extractable {
private.Set(crypto11.CkaExtractable, true)
}
bits := req.Bits
if bits == 0 {
bits = DefaultRSASize
}
switch req.SignatureAlgorithm {
case apiv1.UnspecifiedSignAlgorithm:
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P256())
case apiv1.SHA256WithRSA, apiv1.SHA384WithRSA, apiv1.SHA512WithRSA:
return ctx.GenerateRSAKeyPairWithAttributes(public, private, bits)
case apiv1.SHA256WithRSAPSS, apiv1.SHA384WithRSAPSS, apiv1.SHA512WithRSAPSS:
return ctx.GenerateRSAKeyPairWithAttributes(public, private, bits)
case apiv1.ECDSAWithSHA256:
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P256())
case apiv1.ECDSAWithSHA384:
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P384())
case apiv1.ECDSAWithSHA512:
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P521())
case apiv1.PureEd25519:
return nil, fmt.Errorf("signature algorithm %s is not supported", req.SignatureAlgorithm)
default:
return nil, fmt.Errorf("signature algorithm %s is not supported", req.SignatureAlgorithm)
}
}
func findSigner(ctx P11, rawuri string) (crypto11.Signer, error) {
id, object, err := parseObject(rawuri)
if err != nil {
return nil, err
}
signer, err := ctx.FindKeyPair(id, object)
if err != nil {
return nil, errors.Wrapf(err, "error finding key with uri %s", rawuri)
}
if signer == nil {
return nil, errors.Errorf("key with uri %s not found", rawuri)
}
return signer, nil
}
func findCertificate(ctx P11, rawuri string) (*x509.Certificate, error) {
u, err := uri.ParseWithScheme(Scheme, rawuri)
if err != nil {
return nil, err
}
id, object, serial := u.GetEncoded("id"), u.Get("object"), u.Get("serial")
if len(id) == 0 && object == "" && serial == "" {
return nil, errors.Errorf("key with uri %s is not valid, id, object or serial are required", rawuri)
}
var serialNumber *big.Int
if serial != "" {
b, err := hex.DecodeString(serial)
if err != nil {
return nil, errors.Errorf("key with uri %s is not valid, failed to decode serial", rawuri)
}
serialNumber = new(big.Int).SetBytes(b)
}
cert, err := ctx.FindCertificate(id, toByte(object), serialNumber)
if err != nil {
return nil, errors.Wrapf(err, "error finding certificate with uri %s", rawuri)
}
if cert == nil {
return nil, errors.Errorf("certificate with uri %s not found", rawuri)
}
return cert, nil
}

View file

@ -1,58 +0,0 @@
//go:build !cgo
// +build !cgo
package pkcs11
import (
"context"
"crypto"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
)
var errUnsupported error
func init() {
name := filepath.Base(os.Args[0])
errUnsupported = errors.Errorf("unsupported kms type 'pkcs11': %s is compiled without cgo support", name)
apiv1.Register(apiv1.PKCS11, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return nil, errUnsupported
})
}
// PKCS11 is the implementation of a KMS using the PKCS #11 standard.
type PKCS11 struct{}
// New implements the kms.KeyManager interface and without CGO will always
// return an error.
func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) {
return nil, errUnsupported
}
// GetPublicKey implements the kms.KeyManager interface and without CGO will always
// return an error.
func (*PKCS11) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
return nil, errUnsupported
}
// CreateKey implements the kms.KeyManager interface and without CGO will always
// return an error.
func (*PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
return nil, errUnsupported
}
// CreateSigner implements the kms.KeyManager interface and without CGO will always
// return an error.
func (*PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
return nil, errUnsupported
}
// Close implements the kms.KeyManager interface and without CGO will always
// return an error.
func (*PKCS11) Close() error {
return errUnsupported
}

View file

@ -1,836 +0,0 @@
//go:build cgo
// +build cgo
package pkcs11
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"math/big"
"reflect"
"strings"
"testing"
"github.com/ThalesIgnite/crypto11"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
func TestNew(t *testing.T) {
tmp := p11Configure
t.Cleanup(func() {
p11Configure = tmp
})
k := mustPKCS11(t)
t.Cleanup(func() {
k.Close()
})
p11Configure = func(config *crypto11.Config) (P11, error) {
if strings.Contains(config.Path, "fail") {
return nil, errors.New("an error")
}
return k.p11, nil
}
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
want *PKCS11
wantErr bool
}{
{"ok", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password",
}}, k, false},
{"ok with serial", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789?pin-value=password",
}}, k, false},
{"ok with slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=0?pin-value=password",
}}, k, false},
{"ok with pin", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test",
Pin: "passowrd",
}}, k, false},
{"fail missing module", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:token=pkcs11-test",
Pin: "passowrd",
}}, nil, true},
{"fail missing pin", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test",
}}, nil, true},
{"fail missing token/serial/slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so",
Pin: "passowrd",
}}, nil, true},
{"fail token+serial+slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789;slot-id=0",
Pin: "passowrd",
}}, nil, true},
{"fail token+serial", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789",
Pin: "passowrd",
}}, nil, true},
{"fail token+slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;slot-id=0",
Pin: "passowrd",
}}, nil, true},
{"fail serial+slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789;slot-id=0",
Pin: "passowrd",
}}, nil, true},
{"fail slot-id", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=x?pin-value=password",
}}, nil, true},
{"fail scheme", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "foo:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password",
}}, nil, true},
{"fail configure", args{context.Background(), apiv1.Options{
Type: "pkcs11",
URI: "pkcs11:module-path=/usr/local/lib/fail.so;token=pkcs11-test?pin-value=password",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestPKCS11_GetPublicKey(t *testing.T) {
k := setupPKCS11(t)
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
args args
want crypto.PublicKey
wantErr bool
}{
{"RSA", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:id=7371;object=rsa-key",
}}, &rsa.PublicKey{}, false},
{"RSA by id", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:id=7371",
}}, &rsa.PublicKey{}, false},
{"RSA by label", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:object=rsa-key",
}}, &rsa.PublicKey{}, false},
{"ECDSA", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:id=7373;object=ecdsa-p256-key",
}}, &ecdsa.PublicKey{}, false},
{"ECDSA by id", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:id=7373",
}}, &ecdsa.PublicKey{}, false},
{"ECDSA by label", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:object=ecdsa-p256-key",
}}, &ecdsa.PublicKey{}, false},
{"fail name", args{&apiv1.GetPublicKeyRequest{
Name: "",
}}, nil, true},
{"fail uri", args{&apiv1.GetPublicKeyRequest{
Name: "https:id=9999;object=https",
}}, nil, true},
{"fail missing", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:id=9999;object=rsa-key",
}}, nil, true},
{"fail FindKeyPair", args{&apiv1.GetPublicKeyRequest{
Name: "pkcs11:foo=bar",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if reflect.TypeOf(got) != reflect.TypeOf(tt.want) {
t.Errorf("PKCS11.GetPublicKey() = %T, want %T", got, tt.want)
}
})
}
}
func TestPKCS11_CreateKey(t *testing.T) {
k := setupPKCS11(t)
// Make sure to delete the created key
k.DeleteKey(testObject)
type args struct {
req *apiv1.CreateKeyRequest
}
tests := []struct {
name string
args args
want *apiv1.CreateKeyResponse
wantErr bool
}{
{"default", args{&apiv1.CreateKeyRequest{
Name: testObject,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &ecdsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"default extractable", args{&apiv1.CreateKeyRequest{
Name: testObject,
Extractable: true,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &ecdsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA256WithRSA", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA256WithRSA,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA384WithRSA", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA384WithRSA,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA512WithRSA", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA512WithRSA,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA256WithRSAPSS", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA384WithRSAPSS", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA SHA512WithRSAPSS", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA 2048", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA256WithRSA,
Bits: 2048,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"RSA 4096", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.SHA256WithRSA,
Bits: 4096,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &rsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"ECDSA P256", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &ecdsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"ECDSA P384", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &ecdsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"ECDSA P521", args{&apiv1.CreateKeyRequest{
Name: testObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
}}, &apiv1.CreateKeyResponse{
Name: testObject,
PublicKey: &ecdsa.PublicKey{},
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: testObject,
},
}, false},
{"fail name", args{&apiv1.CreateKeyRequest{
Name: "",
}}, nil, true},
{"fail no id", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:object=create-key",
}}, nil, true},
{"fail no object", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:id=9999",
}}, nil, true},
{"fail schema", args{&apiv1.CreateKeyRequest{
Name: "pkcs12:id=9999;object=create-key",
}}, nil, true},
{"fail bits", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:id=9999;object=create-key",
Bits: -1,
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
}}, nil, true},
{"fail ed25519", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:id=9999;object=create-key",
SignatureAlgorithm: apiv1.PureEd25519,
}}, nil, true},
{"fail unknown", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:id=9999;object=create-key",
SignatureAlgorithm: apiv1.SignatureAlgorithm(100),
}}, nil, true},
{"fail FindKeyPair", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:foo=bar",
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
}}, nil, true},
{"fail already exists", args{&apiv1.CreateKeyRequest{
Name: "pkcs11:id=7373;object=ecdsa-p256-key",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
got.PublicKey = tt.want.PublicKey
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PKCS11.CreateKey() = %v, want %v", got, tt.want)
}
if got != nil {
if err := k.DeleteKey(got.Name); err != nil {
t.Errorf("PKCS11.DeleteKey() error = %v", err)
}
}
})
}
}
func TestPKCS11_CreateSigner(t *testing.T) {
k := setupPKCS11(t)
data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
// public key, pub. Its return value records whether the signature is valid.
verifyASN1 := func(pub *ecdsa.PublicKey, hash, sig []byte) bool {
var (
r, s = &big.Int{}, &big.Int{}
inner cryptobyte.String
)
input := cryptobyte.String(sig)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(r) ||
!inner.ReadASN1Integer(s) ||
!inner.Empty() {
return false
}
return ecdsa.Verify(pub, hash, r, s)
}
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
args args
algorithm apiv1.SignatureAlgorithm
signerOpts crypto.SignerOpts
wantErr bool
}{
// SoftHSM2
{"RSA", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7371;object=rsa-key",
}}, apiv1.SHA256WithRSA, crypto.SHA256, false},
{"RSA PSS", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7372;object=rsa-pss-key",
}}, apiv1.SHA256WithRSAPSS, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: crypto.SHA256,
}, false},
{"ECDSA P256", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key",
}}, apiv1.ECDSAWithSHA256, crypto.SHA256, false},
{"ECDSA P384", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7374;object=ecdsa-p384-key",
}}, apiv1.ECDSAWithSHA384, crypto.SHA384, false},
{"ECDSA P521", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:id=7375;object=ecdsa-p521-key",
}}, apiv1.ECDSAWithSHA512, crypto.SHA512, false},
{"fail SigningKey", args{&apiv1.CreateSignerRequest{
SigningKey: "",
}}, 0, nil, true},
{"fail uri", args{&apiv1.CreateSignerRequest{
SigningKey: "https:id=7375;object=ecdsa-p521-key",
}}, 0, nil, true},
{"fail FindKeyPair", args{&apiv1.CreateSignerRequest{
SigningKey: "pkcs11:foo=bar",
}}, 0, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
hash := tt.signerOpts.HashFunc()
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
sig, err := got.Sign(rand.Reader, digest, tt.signerOpts)
if err != nil {
t.Errorf("cyrpto.Signer.Sign() error = %v", err)
}
switch tt.algorithm {
case apiv1.SHA256WithRSA, apiv1.SHA384WithRSA, apiv1.SHA512WithRSA:
pub := got.Public().(*rsa.PublicKey)
if err := rsa.VerifyPKCS1v15(pub, hash, digest, sig); err != nil {
t.Errorf("rsa.VerifyPKCS1v15() error = %v", err)
}
case apiv1.UnspecifiedSignAlgorithm, apiv1.SHA256WithRSAPSS, apiv1.SHA384WithRSAPSS, apiv1.SHA512WithRSAPSS:
pub := got.Public().(*rsa.PublicKey)
if err := rsa.VerifyPSS(pub, hash, digest, sig, tt.signerOpts.(*rsa.PSSOptions)); err != nil {
t.Errorf("rsa.VerifyPSS() error = %v", err)
}
case apiv1.ECDSAWithSHA256, apiv1.ECDSAWithSHA384, apiv1.ECDSAWithSHA512:
pub := got.Public().(*ecdsa.PublicKey)
if !verifyASN1(pub, digest, sig) {
t.Error("ecdsa.VerifyASN1() failed")
}
default:
t.Errorf("signature algorithm %s is not supported", tt.algorithm)
}
}
})
}
}
func TestPKCS11_CreateDecrypter(t *testing.T) {
k := setupPKCS11(t)
data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")
type args struct {
req *apiv1.CreateDecrypterRequest
}
tests := []struct {
name string
args args
wantErr bool
}{
{"RSA", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7371;object=rsa-key",
}}, false},
{"RSA PSS", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7372;object=rsa-pss-key",
}}, false},
{"ECDSA P256", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7373;object=ecdsa-p256-key",
}}, true},
{"ECDSA P384", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7374;object=ecdsa-p384-key",
}}, true},
{"ECDSA P521", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7375;object=ecdsa-p521-key",
}}, true},
{"fail DecryptionKey", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "",
}}, true},
{"fail uri", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "https:id=7375;object=ecdsa-p521-key",
}}, true},
{"fail FindKeyPair", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:foo=bar",
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.CreateDecrypter(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.CreateDecrypter() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
pub := got.Public().(*rsa.PublicKey)
// PKCS#1 v1.5
enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data)
if err != nil {
t.Errorf("rsa.EncryptPKCS1v15() error = %v", err)
return
}
dec, err := got.Decrypt(rand.Reader, enc, nil)
if err != nil {
t.Errorf("PKCS1v15.Decrypt() error = %v", err)
} else if !bytes.Equal(dec, data) {
t.Errorf("PKCS1v15.Decrypt() failed got = %s, want = %s", dec, data)
}
// RSA-OAEP
enc, err = rsa.EncryptOAEP(crypto.SHA256.New(), rand.Reader, pub, data, []byte("label"))
if err != nil {
t.Errorf("rsa.EncryptOAEP() error = %v", err)
return
}
dec, err = got.Decrypt(rand.Reader, enc, &rsa.OAEPOptions{
Hash: crypto.SHA256,
Label: []byte("label"),
})
if err != nil {
t.Errorf("RSA-OAEP.Decrypt() error = %v", err)
} else if !bytes.Equal(dec, data) {
t.Errorf("RSA-OAEP.Decrypt() RSA-OAEP failed got = %s, want = %s", dec, data)
}
}
})
}
}
func TestPKCS11_LoadCertificate(t *testing.T) {
k := setupPKCS11(t)
getCertFn := func(i, j int) func() *x509.Certificate {
return func() *x509.Certificate {
return testCerts[i].Certificates[j]
}
}
type args struct {
req *apiv1.LoadCertificateRequest
}
tests := []struct {
name string
args args
wantFn func() *x509.Certificate
wantErr bool
}{
{"load", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:id=7376;object=test-root",
}}, getCertFn(0, 0), false},
{"load by id", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:id=7376",
}}, getCertFn(0, 0), false},
{"load by label", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:object=test-root",
}}, getCertFn(0, 0), false},
{"load by serial", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:serial=64",
}}, getCertFn(0, 0), false},
{"fail missing", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:id=9999;object=test-root",
}}, nil, true},
{"fail name", args{&apiv1.LoadCertificateRequest{
Name: "",
}}, nil, true},
{"fail scheme", args{&apiv1.LoadCertificateRequest{
Name: "foo:id=7376;object=test-root",
}}, nil, true},
{"fail serial", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:serial=foo",
}}, nil, true},
{"fail FindCertificate", args{&apiv1.LoadCertificateRequest{
Name: "pkcs11:foo=bar",
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.LoadCertificate(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.LoadCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
var want *x509.Certificate
if tt.wantFn != nil {
want = tt.wantFn()
got.Raw, got.RawIssuer, got.RawSubject, got.RawTBSCertificate, got.RawSubjectPublicKeyInfo = nil, nil, nil, nil, nil
want.Raw, want.RawIssuer, want.RawSubject, want.RawTBSCertificate, want.RawSubjectPublicKeyInfo = nil, nil, nil, nil, nil
}
if !reflect.DeepEqual(got, want) {
t.Errorf("PKCS11.LoadCertificate() = %v, want %v", got, want)
}
})
}
}
func TestPKCS11_StoreCertificate(t *testing.T) {
k := setupPKCS11(t)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("ed25519.GenerateKey() error = %v", err)
}
cert, err := generateCertificate(pub, priv)
if err != nil {
t.Fatalf("x509.CreateCertificate() error = %v", err)
}
// Make sure to delete the created certificate
t.Cleanup(func() {
k.DeleteCertificate(testObject)
k.DeleteCertificate(testObjectAlt)
})
type args struct {
req *apiv1.StoreCertificateRequest
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{&apiv1.StoreCertificateRequest{
Name: testObject,
Certificate: cert,
}}, false},
{"ok extractable", args{&apiv1.StoreCertificateRequest{
Name: testObjectAlt,
Certificate: cert,
Extractable: true,
}}, false},
{"fail already exists", args{&apiv1.StoreCertificateRequest{
Name: testObject,
Certificate: cert,
}}, true},
{"fail name", args{&apiv1.StoreCertificateRequest{
Name: "",
Certificate: cert,
}}, true},
{"fail certificate", args{&apiv1.StoreCertificateRequest{
Name: testObject,
Certificate: nil,
}}, true},
{"fail uri", args{&apiv1.StoreCertificateRequest{
Name: "http:id=7770;object=create-cert",
Certificate: cert,
}}, true},
{"fail missing id", args{&apiv1.StoreCertificateRequest{
Name: "pkcs11:object=create-cert",
Certificate: cert,
}}, true},
{"fail missing object", args{&apiv1.StoreCertificateRequest{
Name: "pkcs11:id=7770;object=",
Certificate: cert,
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.args.req.Extractable {
if testModule == "SoftHSM2" {
t.Skip("Extractable certificates are not supported on SoftHSM2")
}
}
if err := k.StoreCertificate(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("PKCS11.StoreCertificate() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
got, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: tt.args.req.Name,
})
if err != nil {
t.Errorf("PKCS11.LoadCertificate() error = %v", err)
}
if !reflect.DeepEqual(got, cert) {
t.Errorf("PKCS11.LoadCertificate() = %v, want %v", got, cert)
}
}
})
}
}
func TestPKCS11_DeleteKey(t *testing.T) {
k := setupPKCS11(t)
type args struct {
uri string
}
tests := []struct {
name string
args args
wantErr bool
}{
{"delete", args{testObject}, false},
{"delete by id", args{testObjectByID}, false},
{"delete by label", args{testObjectByLabel}, false},
{"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false},
{"fail name", args{""}, true},
{"fail FindKeyPair", args{"pkcs11:foo=bar"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: testObject,
}); err != nil {
t.Fatalf("PKCS1.CreateKey() error = %v", err)
}
if err := k.DeleteKey(tt.args.uri); (err != nil) != tt.wantErr {
t.Errorf("PKCS11.DeleteKey() error = %v, wantErr %v", err, tt.wantErr)
}
if _, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Name: tt.args.uri,
}); err == nil {
t.Error("PKCS11.GetPublicKey() public key found and not expected")
}
// Make sure to delete the created one.
if err := k.DeleteKey(testObject); err != nil {
t.Errorf("PKCS11.DeleteKey() error = %v", err)
}
})
}
}
func TestPKCS11_DeleteCertificate(t *testing.T) {
k := setupPKCS11(t)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("ed25519.GenerateKey() error = %v", err)
}
cert, err := generateCertificate(pub, priv)
if err != nil {
t.Fatalf("x509.CreateCertificate() error = %v", err)
}
type args struct {
uri string
}
tests := []struct {
name string
args args
wantErr bool
}{
{"delete", args{testObject}, false},
{"delete by id", args{testObjectByID}, false},
{"delete by label", args{testObjectByLabel}, false},
{"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false},
{"fail name", args{""}, true},
{"fail DeleteCertificate", args{"pkcs11:foo=bar"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: testObject,
Certificate: cert,
}); err != nil {
t.Fatalf("PKCS11.StoreCertificate() error = %v", err)
}
if err := k.DeleteCertificate(tt.args.uri); (err != nil) != tt.wantErr {
t.Errorf("PKCS11.DeleteCertificate() error = %v, wantErr %v", err, tt.wantErr)
}
if _, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: tt.args.uri,
}); err == nil {
t.Error("PKCS11.LoadCertificate() certificate found and not expected")
}
// Make sure to delete the created one.
if err := k.DeleteCertificate(testObject); err != nil {
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
}
})
}
}
func TestPKCS11_Close(t *testing.T) {
k := mustPKCS11(t)
tests := []struct {
name string
wantErr bool
}{
{"ok", false},
{"second", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("PKCS11.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,145 +0,0 @@
//go:build cgo
// +build cgo
package pkcs11
import (
"crypto"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
)
var (
testModule = ""
testObject = "pkcs11:id=7370;object=test-name"
testObjectAlt = "pkcs11:id=7377;object=alt-test-name"
testObjectByID = "pkcs11:id=7370"
testObjectByLabel = "pkcs11:object=test-name"
testKeys = []struct {
Name string
SignatureAlgorithm apiv1.SignatureAlgorithm
Bits int
}{
{"pkcs11:id=7371;object=rsa-key", apiv1.SHA256WithRSA, 2048},
{"pkcs11:id=7372;object=rsa-pss-key", apiv1.SHA256WithRSAPSS, DefaultRSASize},
{"pkcs11:id=7373;object=ecdsa-p256-key", apiv1.ECDSAWithSHA256, 0},
{"pkcs11:id=7374;object=ecdsa-p384-key", apiv1.ECDSAWithSHA384, 0},
{"pkcs11:id=7375;object=ecdsa-p521-key", apiv1.ECDSAWithSHA512, 0},
}
testCerts = []struct {
Name string
Key string
Certificates []*x509.Certificate
}{
{"pkcs11:id=7376;object=test-root", "pkcs11:id=7373;object=ecdsa-p256-key", nil},
}
)
type TBTesting interface {
Helper()
Cleanup(f func())
Log(args ...interface{})
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
Skipf(format string, args ...interface{})
}
func generateCertificate(pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) {
now := time.Now()
template := &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root Certificate"},
Issuer: pkix.Name{CommonName: "Test Root Certificate"},
IsCA: true,
MaxPathLen: 1,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
NotBefore: now,
NotAfter: now.Add(time.Hour),
SerialNumber: big.NewInt(100),
}
b, err := x509.CreateCertificate(rand.Reader, template, template, pub, signer)
if err != nil {
return nil, err
}
return x509.ParseCertificate(b)
}
func setup(t TBTesting, k *PKCS11) {
t.Log("Running using", testModule)
for _, tk := range testKeys {
_, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: tk.Name,
SignatureAlgorithm: tk.SignatureAlgorithm,
Bits: tk.Bits,
})
if err != nil && !errors.Is(errors.Cause(err), apiv1.ErrAlreadyExists{
Message: tk.Name + " already exists",
}) {
t.Errorf("PKCS11.GetPublicKey() error = %v", err)
}
}
for i, c := range testCerts {
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
SigningKey: c.Key,
})
if err != nil {
t.Errorf("PKCS11.CreateSigner() error = %v", err)
continue
}
cert, err := generateCertificate(signer.Public(), signer)
if err != nil {
t.Errorf("x509.CreateCertificate() error = %v", err)
continue
}
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.Name,
Certificate: cert,
}); err != nil && !errors.Is(errors.Cause(err), apiv1.ErrAlreadyExists{
Message: c.Name + " already exists",
}) {
t.Errorf("PKCS1.StoreCertificate() error = %v", err)
continue
}
testCerts[i].Certificates = append(testCerts[i].Certificates, cert)
}
}
func teardown(t TBTesting, k *PKCS11) {
testObjects := []string{testObject, testObjectByID, testObjectByLabel}
for _, name := range testObjects {
if err := k.DeleteKey(name); err != nil {
t.Errorf("PKCS11.DeleteKey() error = %v", err)
}
if err := k.DeleteCertificate(name); err != nil {
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
}
}
for _, tk := range testKeys {
if err := k.DeleteKey(tk.Name); err != nil {
t.Errorf("PKCS11.DeleteKey() error = %v", err)
}
}
for _, tc := range testCerts {
if err := k.DeleteCertificate(tc.Name); err != nil {
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
}
}
}
func setupPKCS11(t TBTesting) *PKCS11 {
t.Helper()
k := mustPKCS11(t)
t.Cleanup(func() {
k.Close()
})
return k
}

View file

@ -1,62 +0,0 @@
//go:build cgo && softhsm2
// +build cgo,softhsm2
package pkcs11
import (
"runtime"
"sync"
"github.com/ThalesIgnite/crypto11"
)
var softHSM2Once sync.Once
// mustPKCS11 configures a *PKCS11 KMS to be used with SoftHSM2. To initialize
// these tests, we should run:
//
// softhsm2-util --init-token --free \
// --token pkcs11-test --label pkcs11-test \
// --so-pin password --pin password
//
// To delete we should run:
//
// softhsm2-util --delete-token --token pkcs11-test
func mustPKCS11(t TBTesting) *PKCS11 {
t.Helper()
testModule = "SoftHSM2"
if runtime.GOARCH != "amd64" {
t.Fatalf("softHSM2 test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
}
var path string
switch runtime.GOOS {
case "darwin":
path = "/usr/local/lib/softhsm/libsofthsm2.so"
case "linux":
path = "/usr/lib/softhsm/libsofthsm2.so"
default:
t.Skipf("softHSM2 test skipped on %s", runtime.GOOS)
return nil
}
p11, err := crypto11.Configure(&crypto11.Config{
Path: path,
TokenLabel: "pkcs11-test",
Pin: "password",
})
if err != nil {
t.Fatalf("failed to configure softHSM2 on %s: %v", runtime.GOOS, err)
}
k := &PKCS11{
p11: p11,
}
// Setup
softHSM2Once.Do(func() {
teardown(t, k)
setup(t, k)
})
return k
}

View file

@ -1,56 +0,0 @@
//go:build cgo && yubihsm2
// +build cgo,yubihsm2
package pkcs11
import (
"runtime"
"sync"
"github.com/ThalesIgnite/crypto11"
)
var yubiHSM2Once sync.Once
// mustPKCS11 configures a *PKCS11 KMS to be used with YubiHSM2. To initialize
// these tests, we should run:
//
// yubihsm-connector -d
func mustPKCS11(t TBTesting) *PKCS11 {
t.Helper()
testModule = "YubiHSM2"
if runtime.GOARCH != "amd64" {
t.Skipf("yubiHSM2 test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
}
var path string
switch runtime.GOOS {
case "darwin":
path = "/usr/local/lib/pkcs11/yubihsm_pkcs11.dylib"
case "linux":
path = "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so"
default:
t.Skipf("yubiHSM2 test skipped on %s", runtime.GOOS)
return nil
}
p11, err := crypto11.Configure(&crypto11.Config{
Path: path,
TokenLabel: "YubiHSM",
Pin: "0001password",
})
if err != nil {
t.Fatalf("failed to configure YubiHSM2 on %s: %v", runtime.GOOS, err)
}
k := &PKCS11{
p11: p11,
}
// Setup
yubiHSM2Once.Do(func() {
teardown(t, k)
setup(t, k)
})
return k
}

View file

@ -1,183 +0,0 @@
package softkms
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/cli-utils/ui"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
)
type algorithmAttributes struct {
Type string
Curve string
}
// DefaultRSAKeySize is the default size for RSA keys.
const DefaultRSAKeySize = 3072
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{
apiv1.UnspecifiedSignAlgorithm: {"EC", "P-256"},
apiv1.SHA256WithRSA: {"RSA", ""},
apiv1.SHA384WithRSA: {"RSA", ""},
apiv1.SHA512WithRSA: {"RSA", ""},
apiv1.SHA256WithRSAPSS: {"RSA", ""},
apiv1.SHA384WithRSAPSS: {"RSA", ""},
apiv1.SHA512WithRSAPSS: {"RSA", ""},
apiv1.ECDSAWithSHA256: {"EC", "P-256"},
apiv1.ECDSAWithSHA384: {"EC", "P-384"},
apiv1.ECDSAWithSHA512: {"EC", "P-521"},
apiv1.PureEd25519: {"OKP", "Ed25519"},
}
// generateKey is used for testing purposes.
var generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
if kty == "RSA" && size == 0 {
size = DefaultRSAKeySize
}
return keyutil.GenerateKeyPair(kty, crv, size)
}
// SoftKMS is a key manager that uses keys stored in disk.
type SoftKMS struct{}
// New returns a new SoftKMS.
func New(ctx context.Context, opts apiv1.Options) (*SoftKMS, error) {
return &SoftKMS{}, nil
}
func init() {
pemutil.PromptPassword = func(msg string) ([]byte, error) {
return ui.PromptPassword(msg)
}
apiv1.Register(apiv1.SoftKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// Close is a noop that just returns nil.
func (k *SoftKMS) Close() error {
return nil
}
// CreateSigner returns a new signer configured with the given signing key.
func (k *SoftKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
var opts []pemutil.Options
if req.Password != nil {
opts = append(opts, pemutil.WithPassword(req.Password))
}
switch {
case req.Signer != nil:
return req.Signer, nil
case len(req.SigningKeyPEM) != 0:
v, err := pemutil.ParseKey(req.SigningKeyPEM, opts...)
if err != nil {
return nil, err
}
sig, ok := v.(crypto.Signer)
if !ok {
return nil, errors.New("signingKeyPEM is not a crypto.Signer")
}
return sig, nil
case req.SigningKey != "":
v, err := pemutil.Read(req.SigningKey, opts...)
if err != nil {
return nil, err
}
sig, ok := v.(crypto.Signer)
if !ok {
return nil, errors.New("signingKey is not a crypto.Signer")
}
return sig, nil
default:
return nil, errors.New("failed to load softKMS: please define signingKeyPEM or signingKey")
}
}
// CreateKey generates a new key using Golang crypto and returns both public and
// private key.
func (k *SoftKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
if !ok {
return nil, errors.Errorf("softKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
}
pub, priv, err := generateKey(v.Type, v.Curve, req.Bits)
if err != nil {
return nil, err
}
signer, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.Errorf("softKMS createKey result is not a crypto.Signer: type %T", priv)
}
return &apiv1.CreateKeyResponse{
Name: req.Name,
PublicKey: pub,
PrivateKey: priv,
CreateSignerRequest: apiv1.CreateSignerRequest{
Signer: signer,
},
}, nil
}
// GetPublicKey returns the public key from the file passed in the request name.
func (k *SoftKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
v, err := pemutil.Read(req.Name)
if err != nil {
return nil, err
}
switch vv := v.(type) {
case *x509.Certificate:
return vv.PublicKey, nil
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
return vv, nil
default:
return nil, errors.Errorf("unsupported public key type %T", v)
}
}
// CreateDecrypter creates a new crypto.Decrypter backed by disk/software
func (k *SoftKMS) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
var opts []pemutil.Options
if req.Password != nil {
opts = append(opts, pemutil.WithPassword(req.Password))
}
switch {
case req.Decrypter != nil:
return req.Decrypter, nil
case len(req.DecryptionKeyPEM) != 0:
v, err := pemutil.ParseKey(req.DecryptionKeyPEM, opts...)
if err != nil {
return nil, err
}
decrypter, ok := v.(crypto.Decrypter)
if !ok {
return nil, errors.New("decryptorKeyPEM is not a crypto.Decrypter")
}
return decrypter, nil
case req.DecryptionKey != "":
v, err := pemutil.Read(req.DecryptionKey, opts...)
if err != nil {
return nil, err
}
decrypter, ok := v.(crypto.Decrypter)
if !ok {
return nil, errors.New("decryptionKey is not a crypto.Decrypter")
}
return decrypter, nil
default:
return nil, errors.New("failed to load softKMS: please define decryptionKeyPEM or decryptionKey")
}
}

View file

@ -1,381 +0,0 @@
package softkms
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"reflect"
"testing"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/pemutil"
)
func TestNew(t *testing.T) {
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
want *SoftKMS
wantErr bool
}{
{"ok", args{context.Background(), apiv1.Options{}}, &SoftKMS{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(tt.args.ctx, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestSoftKMS_Close(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{"ok", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SoftKMS{}
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("SoftKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSoftKMS_CreateSigner(t *testing.T) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
pemBlock, err := pemutil.Serialize(pk)
if err != nil {
t.Fatal(err)
}
pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
if err != nil {
t.Fatal(err)
}
// Read and decode file using standard packages
b, err := os.ReadFile("testdata/priv.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(b)
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
if err != nil {
t.Fatal(err)
}
pk2, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
// Create a public PEM
b, err = x509.MarshalPKIXPublicKey(pk.Public())
if err != nil {
t.Fatal(err)
}
pub := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: b,
})
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
args args
want crypto.Signer
wantErr bool
}{
{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SoftKMS{}
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SoftKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SoftKMS.CreateSigner() = %v, want %v", got, tt.want)
}
})
}
}
func restoreGenerateKey() func() {
oldGenerateKey := generateKey
return func() {
generateKey = oldGenerateKey
}
}
func TestSoftKMS_CreateKey(t *testing.T) {
fn := restoreGenerateKey()
defer fn()
p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.CreateKeyRequest
}
type params struct {
kty string
crv string
size int
}
tests := []struct {
name string
args args
generateKey func() (interface{}, interface{}, error)
want *apiv1.CreateKeyResponse
wantParams params
wantErr bool
}{
{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
return edpub, edpriv, nil
}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, nil, params{}, true},
{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return nil, nil, fmt.Errorf("an error")
}, nil, params{"EC", "P-256", 0}, true},
{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return 1, 2, nil
}, nil, params{"EC", "P-256", 0}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SoftKMS{}
generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
if tt.wantParams.kty != kty {
t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
}
if tt.wantParams.crv != crv {
t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
}
if tt.wantParams.size != size {
t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
}
return tt.generateKey()
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SoftKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SoftKMS.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestSoftKMS_GetPublicKey(t *testing.T) {
b, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(b)
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
args args
want crypto.PublicKey
wantErr bool
}{
{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SoftKMS{}
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SoftKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SoftKMS.GetPublicKey() = %v, want %v", got, tt.want)
}
})
}
}
func Test_generateKey(t *testing.T) {
type args struct {
kty string
crv string
size int
}
tests := []struct {
name string
args args
wantType interface{}
wantType1 interface{}
wantErr bool
}{
{"rsa2048", args{"RSA", "", 0}, &rsa.PublicKey{}, &rsa.PrivateKey{}, false},
{"rsa2048", args{"RSA", "", 2048}, &rsa.PublicKey{}, &rsa.PrivateKey{}, false},
{"p256", args{"EC", "P-256", 0}, &ecdsa.PublicKey{}, &ecdsa.PrivateKey{}, false},
{"ed25519", args{"OKP", "Ed25519", 0}, ed25519.PublicKey{}, ed25519.PrivateKey{}, false},
{"fail kty", args{"FOO", "", 0}, nil, nil, true},
{"fail crv", args{"EC", "P-123", 0}, nil, nil, true},
{"fail size", args{"RSA", "", 1}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := generateKey(tt.args.kty, tt.args.crv, tt.args.size)
if (err != nil) != tt.wantErr {
t.Errorf("generateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if reflect.TypeOf(got) != reflect.TypeOf(tt.wantType) {
t.Errorf("generateKey() got = %T, want %T", got, tt.wantType)
}
if reflect.TypeOf(got1) != reflect.TypeOf(tt.wantType1) {
t.Errorf("generateKey() got1 = %T, want %T", got1, tt.wantType1)
}
})
}
}
func TestSoftKMS_CreateDecrypter(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
pemBlock, err := pemutil.Serialize(privateKey)
if err != nil {
t.Fatal(err)
}
pemBlockPassword, err := pemutil.Serialize(privateKey, pemutil.WithPassword([]byte("pass")))
if err != nil {
t.Fatal(err)
}
ecdsaPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
ecdsaPemBlock, err := pemutil.Serialize(ecdsaPK)
if err != nil {
t.Fatal(err)
}
b, err := os.ReadFile("testdata/rsa.priv.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(b)
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
if err != nil {
t.Fatal(err)
}
keyFromFile, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.CreateDecrypterRequest
}
tests := []struct {
name string
args args
want crypto.Decrypter
wantErr bool
}{
{"decrypter", args{&apiv1.CreateDecrypterRequest{Decrypter: privateKey}}, privateKey, false},
{"file", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/rsa.priv.pem", Password: []byte("pass")}}, keyFromFile, false},
{"pem", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlock)}}, privateKey, false},
{"pem password", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, privateKey, false},
{"fail none", args{&apiv1.CreateDecrypterRequest{}}, nil, true},
{"fail missing", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/missing"}}, nil, true},
{"fail bad pem", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: []byte("bad pem")}}, nil, true},
{"fail bad password", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("bad-pass")}}, nil, true},
{"fail not a decrypter (ecdsa key)", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(ecdsaPemBlock)}}, nil, true},
{"fail not a decrypter from file", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/rsa.pub.pem"}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SoftKMS{}
got, err := k.CreateDecrypter(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SoftKMS.CreateDecrypter(), error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SoftKMS.CreateDecrypter() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,11 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIBpzCCAU2gAwIBAgIQWaY8KIDAfak8aYljelf8eTAKBggqhkjOPQQDAjAdMRsw
GQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wHhcNMjAwMTE2MDAwNDU4WhcNMjAw
MTE3MDAwNDU4WjAdMRswGQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wWTATBgcq
hkjOPQIBBggqhkjOPQMBBwNCAATlU8P9blFefSWuzYx2g215NJn6yHW95PXeFqQ9
kX1jNo1VmC6Oord3We37iM8QJT4QP9ZDUaAVmJUZSjd+W8H/o28wbTAOBgNVHQ8B
Af8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
BBTn0wonKkm2lLRNYZrKhUukiynvqzAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3Rl
cC5jb20wCgYIKoZIzj0EAwIDSAAwRQIhAJ5XqryBIY1X4fl/9l0isV69eQfA0Qo5
1mjervUcEnOWAiBsmN4frz5YVw7i4UXChVBeZLZfJOKvn5eyh2gEzoq1+w==
-----END CERTIFICATE-----

View file

@ -1,5 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEICB6lIrMa9fVQJtdAYS4qmdYQ1BHJsEQDx8zxL38gA8toAoGCCqGSM49
AwEHoUQDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1veT13hakPZF9YzaNVZgujqK3
d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
-----END EC PRIVATE KEY-----

View file

@ -1,8 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,1fcec5dfbf3327f61bfe5ab6ae8a0626
V39b/pNHMbP80TXSHLsUY6UOTCzf3KwIxvj1e7S9brNMJJc9b3UiloMBJIYBkl00
NKI8JU4jSlcerR58DqsTHIELiX6a+RJLe3/iR2/5Gru+CmmWJ68jQu872WCgh6Ms
o8TzhyGx74ETmdKn5CdtylsnKMa9heW3tBLFAbNCgKc=
-----END EC PRIVATE KEY-----

View file

@ -1,4 +0,0 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
-----END PUBLIC KEY-----

View file

@ -1,30 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,dff7bfd0e0163a4cd7ade8f68b966699
jtmOhr2zo244Oq2fVsShZAUoQZ1gi6Iwc4i0sReU66XP9CFkdvJasAfkjQGrbCEy
m2+r7W6aH+L3j/4sXcJe8h4UVnnC4DHCozmtqqFCq7cFS4TiVpco26wEVH5WLm7Y
3Ew/pL0k24E+Ycf+yV5c1tQXRlmsKubjwzrZtGZP2yn3Dxsu97mzOXAfx7r+DIKI
5a4S3m1/yXw76tt6Iho9h4huA25UUDHKUQvOGd5gmOKqJRV9djoyu85ODbmz5nt0
pB2EzdHOrefgd0rcQQPI1uFBWqASJxTn+uS7ZBP4rlCcs932lI1mPerMh1ujo51F
3aibrwhKE6kaJyOOnUbvyBnaiTb5i4WwTqx/jfsOsggXQb3UlxgDph48VXw8O2jF
CQmle+TR8yr1A14/Dno5Dd4cqPv6AmWWU2zolvLxKQixFcvjsyQYCDajWWRPkOgj
RTKXDqL1mpjrlDqcSXzemCWk6FzqdUQhimhFgARDRfRwwDeWQN5ua4a3gnem/cpA
ZS8J45H0ZC/CxGPfp+qx75n5a875+n4VMmCZerXPzEIj1CzS7D6BVAXTHJaNIB6S
0WNfQnftp09O2l6iXBE+MHt5bVxqt46+vgcceSu7Gsb3ZfD79vnQ7tR+wb+xmHKk
8rVcMrB+kDRXVguH/a3zUGYAEnb6hPkIJywJVD4G65oM+D9D67Mdka8wIMK48doV
my8a0MfT/9AidR6XJVxIkHlPsPzlxirm/NKF7oSlzurcvYcPAYnHYLW2uB8dyidq
1zB+3rxbSYCVqrhqzN4prydGvkIE3/+AJyIGn7uGSTSSyF6BC9APXQaHplRGKwLz
efOIMoEwXJ1DIcKmk9GB65xxrZxMu3Cclcbc4PgY4370G0PfCHuUQNQL2RUWCQn0
aax+qDiFg1LsLRaI75OaLJ+uKs6rRfytQMmFGqK/b6iVbktiYWMtrDJDo4OUTtZ6
LBBySH7sAFgI3IIxct2Fwg8X1J4kfHr9jWTLjMEIE2o8cyqvSQ8rdwA25MxRcn75
DGqSlGE6Sx0XhWCVUiZidVRSYGKmOmH9yw8cjKm17qL23t8Gwns4Xunl7V6YlTCG
BPw5f1jWCQ94TwvUSuHMPYoXlYwRoe+jfDAzp2AQwXqvWX5Qno5PKz9gQ5iYacZ/
k82fyPbk2XLDkPnaNJKnyiIc252O0WffUlX6Rlv3aF8ZgVvWfZbuHEK6g1W+IKSA
pXAQ+iZBl+fjs/wT0yZSNTB0P1InD9Ve536L94gxXoeMr6F0Eouk3J2R9qdFp0Av
31xylRKSmzUf87/sRxjy3FzSTjIal77y1euJoAEU/nShmNrAZ6B8wnlvHfVwbgmt
xWqxYIi/j/C8Led9uhEhX2WjPsO7ckGA41Tw6hZk/5hr4jmPoZQKHf9OauJFujMh
ybPRQ6SGZJaYQAgpEGHSHFm8lwf5/DcezdSMdzqAKBWJBv6MediMuS60wcJ0Tebk
rdLkNE4bsxfc889BkXBrSqfd+Auu5RcF/kF44gLL7oj4ojQyV44vLZbC4+liGThT
bhayYGV64hsY+zL03u5wVfF1Y+33/uc8o/0JjbfuW5AIdikVES/jnKKFXSTMNL69
-----END RSA PRIVATE KEY-----

View file

@ -1,9 +0,0 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAn2Oh7/uWB5RH40la1a43
IRaLZ8EnJVw5DCKE3BUre8xflVY2wTIS7XHcY0fEGprtq7hzFKors9AIGGn2yGrf
bZX2I+1g+RtQ6cLL6koeLuhRDqCuae0lZPulWc5ixBmM9mpl4ARRcpQFldxFRhis
xUaHMx8VqdZjFSDc5CJHYYK1n2G5DyuzJCk6yOfyMpwxizZJF4IUyqV7zKmZv1z9
/Xd8X0ag7jRdaTBpupJ1WLaq7LlvyB4nr47JXXkLFbRIL1F/gTcPtg0tdEZiKnxs
VLKwOs3VjhEorUwhmVxr4NnNX/0tuOY1FJ0mx5jKLAevqLVwK2JIg/f3h7JcNxDy
tQIDAQAB
-----END PUBLIC KEY-----

View file

@ -1,206 +0,0 @@
package sshagentkms
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"io"
"net"
"os"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/pemutil"
)
// SSHAgentKMS is a key manager that uses keys provided by ssh-agent
type SSHAgentKMS struct {
agentClient agent.Agent
}
// New returns a new SSHAgentKMS.
func New(ctx context.Context, opts apiv1.Options) (*SSHAgentKMS, error) {
socket := os.Getenv("SSH_AUTH_SOCK")
conn, err := net.Dial("unix", socket)
if err != nil {
return nil, errors.Wrap(err, "failed to open SSH_AUTH_SOCK")
}
agentClient := agent.NewClient(conn)
return &SSHAgentKMS{
agentClient: agentClient,
}, nil
}
// NewFromAgent initializes an SSHAgentKMS from a given agent, this method is
// used for testing purposes.
func NewFromAgent(ctx context.Context, opts apiv1.Options, agentClient agent.Agent) (*SSHAgentKMS, error) {
return &SSHAgentKMS{
agentClient: agentClient,
}, nil
}
func init() {
apiv1.Register(apiv1.SSHAgentKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// Close closes the agent. This is a noop for the SSHAgentKMS.
func (k *SSHAgentKMS) Close() error {
return nil
}
// WrappedSSHSigner is a utility type to wrap a ssh.Signer as a crypto.Signer
type WrappedSSHSigner struct {
Sshsigner ssh.Signer
}
// Public returns the agent public key. The type of this public key is
// *agent.Key.
func (s *WrappedSSHSigner) Public() crypto.PublicKey {
return s.Sshsigner.PublicKey()
}
// Sign signs the given digest using the ssh agent and returns the signature.
func (s *WrappedSSHSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
sig, err := s.Sshsigner.Sign(rand, digest)
if err != nil {
return nil, err
}
return sig.Blob, nil
}
// NewWrappedSignerFromSSHSigner returns a new crypto signer wrapping the given
// one.
func NewWrappedSignerFromSSHSigner(signer ssh.Signer) crypto.Signer {
return &WrappedSSHSigner{signer}
}
func (k *SSHAgentKMS) findKey(signingKey string) (target int, err error) {
if strings.HasPrefix(signingKey, "sshagentkms:") {
var key = strings.TrimPrefix(signingKey, "sshagentkms:")
l, err := k.agentClient.List()
if err != nil {
return -1, err
}
for i, s := range l {
if s.Comment == key {
return i, nil
}
}
}
return -1, errors.Errorf("SSHAgentKMS couldn't find %s", signingKey)
}
// CreateSigner returns a new signer configured with the given signing key.
func (k *SSHAgentKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.Signer != nil {
return req.Signer, nil
}
if strings.HasPrefix(req.SigningKey, "sshagentkms:") {
target, err := k.findKey(req.SigningKey)
if err != nil {
return nil, err
}
s, err := k.agentClient.Signers()
if err != nil {
return nil, err
}
return NewWrappedSignerFromSSHSigner(s[target]), nil
}
// OK: We don't actually care about non-ssh certificates,
// but we can't disable it in step-ca so this code is copy-pasted from
// softkms just to keep step-ca happy.
var opts []pemutil.Options
if req.Password != nil {
opts = append(opts, pemutil.WithPassword(req.Password))
}
switch {
case len(req.SigningKeyPEM) != 0:
v, err := pemutil.ParseKey(req.SigningKeyPEM, opts...)
if err != nil {
return nil, err
}
sig, ok := v.(crypto.Signer)
if !ok {
return nil, errors.New("signingKeyPEM is not a crypto.Signer")
}
return sig, nil
case req.SigningKey != "":
v, err := pemutil.Read(req.SigningKey, opts...)
if err != nil {
return nil, err
}
sig, ok := v.(crypto.Signer)
if !ok {
return nil, errors.New("signingKey is not a crypto.Signer")
}
return sig, nil
default:
return nil, errors.New("failed to load softKMS: please define signingKeyPEM or signingKey")
}
}
// CreateKey generates a new key and returns both public and private key.
func (k *SSHAgentKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
return nil, errors.Errorf("SSHAgentKMS doesn't support generating keys")
}
// GetPublicKey returns the public key from the file passed in the request name.
func (k *SSHAgentKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
var v crypto.PublicKey
if strings.HasPrefix(req.Name, "sshagentkms:") {
target, err := k.findKey(req.Name)
if err != nil {
return nil, err
}
s, err := k.agentClient.Signers()
if err != nil {
return nil, err
}
sshPub := s[target].PublicKey()
sshPubBytes := sshPub.Marshal()
parsed, err := ssh.ParsePublicKey(sshPubBytes)
if err != nil {
return nil, err
}
parsedCryptoKey := parsed.(ssh.CryptoPublicKey)
// Then, we can call CryptoPublicKey() to get the actual crypto.PublicKey
v = parsedCryptoKey.CryptoPublicKey()
} else {
var err error
v, err = pemutil.Read(req.Name)
if err != nil {
return nil, err
}
}
switch vv := v.(type) {
case *x509.Certificate:
return vv.PublicKey, nil
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
return vv, nil
default:
return nil, errors.Errorf("unsupported public key type %T", v)
}
}

View file

@ -1,609 +0,0 @@
package sshagentkms
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"net"
"os"
"os/exec"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"github.com/smallstep/certificates/kms/apiv1"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"go.step.sm/crypto/pemutil"
)
// Some helpers with inspiration from crypto/ssh/agent/client_test.go
// startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
func startOpenSSHAgent(t *testing.T) (client agent.Agent, socket string, cleanup func()) {
/* Always test with OpenSSHAgent
if testing.Short() {
// ssh-agent is not always available, and the key
// types supported vary by platform.
t.Skip("skipping test due to -short")
}
*/
bin, err := exec.LookPath("ssh-agent")
if err != nil {
t.Skip("could not find ssh-agent")
}
cmd := exec.Command(bin, "-s")
cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
cmd.Stderr = new(bytes.Buffer)
out, err := cmd.Output()
if err != nil {
t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
}
// Output looks like:
//
// SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
// SSH_AGENT_PID=15542; export SSH_AGENT_PID;
// echo Agent pid 15542;
fields := bytes.Split(out, []byte(";"))
line := bytes.SplitN(fields[0], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AUTH_SOCK" {
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
}
socket = string(line[1])
line = bytes.SplitN(fields[2], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AGENT_PID" {
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
}
pidStr := line[1]
pid, err := strconv.Atoi(string(pidStr))
if err != nil {
t.Fatalf("Atoi(%q): %v", pidStr, err)
}
conn, err := net.Dial("unix", string(socket))
if err != nil {
t.Fatalf("net.Dial: %v", err)
}
ac := agent.NewClient(conn)
return ac, socket, func() {
proc, _ := os.FindProcess(pid)
if proc != nil {
proc.Kill()
}
conn.Close()
os.RemoveAll(filepath.Dir(socket))
}
}
func startAgent(t *testing.T, sshagent agent.Agent) (client agent.Agent, cleanup func()) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
go agent.ServeAgent(sshagent, c2)
return agent.NewClient(c1), func() {
c1.Close()
c2.Close()
}
}
// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
func startKeyringAgent(t *testing.T) (client agent.Agent, cleanup func()) {
return startAgent(t, agent.NewKeyring())
}
type startTestAgentFunc func(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent)
func startTestOpenSSHAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
sshagent, _, cleanup := startOpenSSHAgent(t)
for _, keyToAdd := range keysToAdd {
err := sshagent.Add(keyToAdd)
if err != nil {
t.Fatalf("sshagent.add: %v", err)
}
}
t.Cleanup(cleanup)
//testAgentInterface(t, sshagent, key, cert, lifetimeSecs)
return sshagent
}
func startTestKeyringAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
sshagent, cleanup := startKeyringAgent(t)
for _, keyToAdd := range keysToAdd {
err := sshagent.Add(keyToAdd)
if err != nil {
t.Fatalf("sshagent.add: %v", err)
}
}
t.Cleanup(cleanup)
//testAgentInterface(t, agent, key, cert, lifetimeSecs)
return sshagent
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := netListener()
if err != nil {
return nil, nil, err
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1, c2, nil
}
// netListener creates a localhost network listener.
func netListener() (net.Listener, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, err
}
}
return listener, nil
}
func TestNew(t *testing.T) {
comment := "Key from OpenSSHAgent"
// Ensure we don't "inherit" any SSH_AUTH_SOCK
os.Unsetenv("SSH_AUTH_SOCK")
sshagent, socket, cleanup := startOpenSSHAgent(t)
os.Setenv("SSH_AUTH_SOCK", socket)
t.Cleanup(func() {
os.Unsetenv("SSH_AUTH_SOCK")
cleanup()
})
// Test that we can't find any signers in the agent before we have loaded them
t.Run("No keys with OpenSSHAgent", func(t *testing.T) {
kms, err := New(context.Background(), apiv1.Options{})
if kms == nil || err != nil {
t.Errorf("New() = %v, %v", kms, err)
}
signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
if err == nil || signer != nil {
t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
}
})
// Load ssh test fixtures
b, err := os.ReadFile("testdata/ssh")
if err != nil {
t.Fatal(err)
}
privateKey, err := ssh.ParseRawPrivateKey(b)
if err != nil {
t.Fatal(err)
}
// And add that key to the agent
err = sshagent.Add(agent.AddedKey{PrivateKey: privateKey, Comment: comment})
if err != nil {
t.Fatalf("sshagent.add: %v", err)
}
// And test that we can find it when it's loaded
t.Run("Keys with OpenSSHAgent", func(t *testing.T) {
kms, err := New(context.Background(), apiv1.Options{})
if kms == nil || err != nil {
t.Errorf("New() = %v, %v", kms, err)
}
signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
if err != nil || signer == nil {
t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
}
})
}
func TestNewFromAgent(t *testing.T) {
type args struct {
ctx context.Context
opts apiv1.Options
}
tests := []struct {
name string
args args
sshagentstarter startTestAgentFunc
wantErr bool
}{
{"ok OpenSSHAgent", args{context.Background(), apiv1.Options{}}, startTestOpenSSHAgent, false},
{"ok KeyringAgent", args{context.Background(), apiv1.Options{}}, startTestKeyringAgent, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewFromAgent(tt.args.ctx, tt.args.opts, tt.sshagentstarter(t))
if (err != nil) != tt.wantErr {
t.Errorf("NewFromAgent() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got == nil {
t.Errorf("NewFromAgent() = %v", got)
}
})
}
}
func TestSSHAgentKMS_Close(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{"ok", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SSHAgentKMS{}
if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("SSHAgentKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSSHAgentKMS_CreateSigner(t *testing.T) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
pemBlock, err := pemutil.Serialize(pk)
if err != nil {
t.Fatal(err)
}
pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
if err != nil {
t.Fatal(err)
}
// Read and decode file using standard packages
b, err := os.ReadFile("testdata/priv.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(b)
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
if err != nil {
t.Fatal(err)
}
pk2, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
// Create a public PEM
b, err = x509.MarshalPKIXPublicKey(pk.Public())
if err != nil {
t.Fatal(err)
}
pub := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: b,
})
// Load ssh test fixtures
sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
if err != nil {
t.Fatal(err)
}
_, comment, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyStr)
if err != nil {
t.Fatal(err)
}
b, err = os.ReadFile("testdata/ssh")
if err != nil {
t.Fatal(err)
}
privateKey, err := ssh.ParseRawPrivateKey(b)
if err != nil {
t.Fatal(err)
}
sshPrivateKey, err := ssh.NewSignerFromKey(privateKey)
if err != nil {
t.Fatal(err)
}
wrappedSSHPrivateKey := NewWrappedSignerFromSSHSigner(sshPrivateKey)
type args struct {
req *apiv1.CreateSignerRequest
}
tests := []struct {
name string
args args
want crypto.Signer
wantErr bool
}{
{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
{"sshagent", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment}}, wrappedSSHPrivateKey, false},
{"sshagent Nonexistant", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:Nonexistant"}}, nil, true},
{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
}
starters := []struct {
name string
starter startTestAgentFunc
}{
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
{"startTestKeyringAgent", startTestKeyringAgent},
}
for _, starter := range starters {
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: privateKey, Comment: comment}))
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
// nolint:gocritic
switch s := got.(type) {
case *WrappedSSHSigner:
gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
wantPkS := string(sshPubKeyStr)
if !reflect.DeepEqual(gotPkS, wantPkS) {
t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", gotPkS, wantPkS)
t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", gotPkS, wantPkS)
}
default:
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", got, tt.want)
t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", got, tt.want)
}
}
})
}
}
}
/*
func restoreGenerateKey() func() {
oldGenerateKey := generateKey
return func() {
generateKey = oldGenerateKey
}
}
*/
/*
func TestSSHAgentKMS_CreateKey(t *testing.T) {
fn := restoreGenerateKey()
defer fn()
p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.CreateKeyRequest
}
type params struct {
kty string
crv string
size int
}
tests := []struct {
name string
args args
generateKey func() (interface{}, interface{}, error)
want *apiv1.CreateKeyResponse
wantParams params
wantErr bool
}{
{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
return edpub, edpriv, nil
}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, nil, params{}, true},
{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return nil, nil, fmt.Errorf("an error")
}, nil, params{"EC", "P-256", 0}, true},
{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return 1, 2, nil
}, nil, params{"EC", "P-256", 0}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SSHAgentKMS{}
generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
if tt.wantParams.kty != kty {
t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
}
if tt.wantParams.crv != crv {
t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
}
if tt.wantParams.size != size {
t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
}
return tt.generateKey()
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SSHAgentKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSHAgentKMS.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
*/
func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
b, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(b)
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
// Load ssh test fixtures
b, err = os.ReadFile("testdata/ssh.pub")
if err != nil {
t.Fatal(err)
}
sshPubKey, comment, _, _, err := ssh.ParseAuthorizedKey(b)
if err != nil {
t.Fatal(err)
}
b, err = os.ReadFile("testdata/ssh")
if err != nil {
t.Fatal(err)
}
// crypto.PrivateKey
sshPrivateKey, err := ssh.ParseRawPrivateKey(b)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.GetPublicKeyRequest
}
tests := []struct {
name string
args args
want crypto.PublicKey
wantErr bool
}{
{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
{"sshagent", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:" + comment}}, sshPubKey, false},
{"sshagent Nonexistant", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:Nonexistant"}}, nil, true},
{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
}
starters := []struct {
name string
starter startTestAgentFunc
}{
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
{"startTestKeyringAgent", startTestKeyringAgent},
}
for _, starter := range starters {
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: sshPrivateKey, Comment: comment}))
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
// nolint:gocritic
switch tt.want.(type) {
case ssh.PublicKey:
// If we want a ssh.PublicKey, protote got to a
got, err = ssh.NewPublicKey(got)
if err != nil {
t.Fatal(err)
}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSHAgentKMS.GetPublicKey() = %T, want %T", got, tt.want)
t.Errorf("SSHAgentKMS.GetPublicKey() = %v, want %v", got, tt.want)
}
})
}
}
}
func TestSSHAgentKMS_CreateKey(t *testing.T) {
starters := []struct {
name string
starter startTestAgentFunc
}{
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
{"startTestKeyringAgent", startTestKeyringAgent},
}
for _, starter := range starters {
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t))
if err != nil {
t.Fatal(err)
}
t.Run(starter.name+"/CreateKey", func(t *testing.T) {
got, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: "sshagentkms:0",
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
})
if got != nil {
t.Error("SSHAgentKMS.CreateKey() shoudn't return a value")
}
if err == nil {
t.Error("SSHAgentKMS.CreateKey() didn't return a value")
}
})
}
}

View file

@ -1,11 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIBpzCCAU2gAwIBAgIQWaY8KIDAfak8aYljelf8eTAKBggqhkjOPQQDAjAdMRsw
GQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wHhcNMjAwMTE2MDAwNDU4WhcNMjAw
MTE3MDAwNDU4WjAdMRswGQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wWTATBgcq
hkjOPQIBBggqhkjOPQMBBwNCAATlU8P9blFefSWuzYx2g215NJn6yHW95PXeFqQ9
kX1jNo1VmC6Oord3We37iM8QJT4QP9ZDUaAVmJUZSjd+W8H/o28wbTAOBgNVHQ8B
Af8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
BBTn0wonKkm2lLRNYZrKhUukiynvqzAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3Rl
cC5jb20wCgYIKoZIzj0EAwIDSAAwRQIhAJ5XqryBIY1X4fl/9l0isV69eQfA0Qo5
1mjervUcEnOWAiBsmN4frz5YVw7i4UXChVBeZLZfJOKvn5eyh2gEzoq1+w==
-----END CERTIFICATE-----

View file

@ -1,5 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEICB6lIrMa9fVQJtdAYS4qmdYQ1BHJsEQDx8zxL38gA8toAoGCCqGSM49
AwEHoUQDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1veT13hakPZF9YzaNVZgujqK3
d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
-----END EC PRIVATE KEY-----

View file

@ -1,8 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,1fcec5dfbf3327f61bfe5ab6ae8a0626
V39b/pNHMbP80TXSHLsUY6UOTCzf3KwIxvj1e7S9brNMJJc9b3UiloMBJIYBkl00
NKI8JU4jSlcerR58DqsTHIELiX6a+RJLe3/iR2/5Gru+CmmWJ68jQu872WCgh6Ms
o8TzhyGx74ETmdKn5CdtylsnKMa9heW3tBLFAbNCgKc=
-----END EC PRIVATE KEY-----

View file

@ -1,4 +0,0 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
-----END PUBLIC KEY-----

View file

@ -1,49 +0,0 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAgEAth/d7zRDbv567o46KT6YYqC/EVdDpZ8m0rzIdroJL+RHVDXNQ1pU
3lrC9IWfkyjX+YwO9jHGbraJ+CgonAkl36mtLzNC4645QGS2/WdFqRR6mQCz7v4G6nOaFN
SCeErMhg0fn4f7jdqXpd0hYozIpktRVNYcpi2RMmr8e/Kadr5EVQfbYZgdKIl1O6Ws9O3Q
1BhLGi9GipEstUTvjqxZzF7oUgWKH54j5eHNXdbFqKqnK8NNQmypNLYGDsTBQHG9zRs+o0
7C2foO9ddIO2OCarcBWZfGlY05k/ZhEmrEOONh2rSLhJwqw+EJgQeU0Poe/IqjFy7jnTRk
i+tee2elBYVvHYPSofZaBmX7i21s8eBRl/ZiFx3ip6E3M54mXvKZ7SuA2qq/YW0IeKyJ5D
SuL0+sRAyiSQ2Icsyb3YKv6LXojuJTmJ9Hg9v4+aOPxOQhvNfh3b7sIh/cmz1dq/babLyO
ORrbHKDxIJME7VPMspmddV9wJgB4Gu1eWOiR/Cuv6jqYWTfiWJDIoqZRD5nF1tFqKtZ5iA
qkflv4Kbo10tv6nTlXR6TWuPu2Z/pZpx+NN+7QxVUSlRgxb7RTVcHRvpgd0TNEXGduR8ar
WVDlNewOmf5KFroW1IX/yR1OvE5RsDixxcX7Ne+uSlq9hooy9V/Ip0ffcF/Kg0NJoPwrnI
MAAAdQrAxluqwMZboAAAAHc3NoLXJzYQAAAgEAth/d7zRDbv567o46KT6YYqC/EVdDpZ8m
0rzIdroJL+RHVDXNQ1pU3lrC9IWfkyjX+YwO9jHGbraJ+CgonAkl36mtLzNC4645QGS2/W
dFqRR6mQCz7v4G6nOaFNSCeErMhg0fn4f7jdqXpd0hYozIpktRVNYcpi2RMmr8e/Kadr5E
VQfbYZgdKIl1O6Ws9O3Q1BhLGi9GipEstUTvjqxZzF7oUgWKH54j5eHNXdbFqKqnK8NNQm
ypNLYGDsTBQHG9zRs+o07C2foO9ddIO2OCarcBWZfGlY05k/ZhEmrEOONh2rSLhJwqw+EJ
gQeU0Poe/IqjFy7jnTRki+tee2elBYVvHYPSofZaBmX7i21s8eBRl/ZiFx3ip6E3M54mXv
KZ7SuA2qq/YW0IeKyJ5DSuL0+sRAyiSQ2Icsyb3YKv6LXojuJTmJ9Hg9v4+aOPxOQhvNfh
3b7sIh/cmz1dq/babLyOORrbHKDxIJME7VPMspmddV9wJgB4Gu1eWOiR/Cuv6jqYWTfiWJ
DIoqZRD5nF1tFqKtZ5iAqkflv4Kbo10tv6nTlXR6TWuPu2Z/pZpx+NN+7QxVUSlRgxb7RT
VcHRvpgd0TNEXGduR8arWVDlNewOmf5KFroW1IX/yR1OvE5RsDixxcX7Ne+uSlq9hooy9V
/Ip0ffcF/Kg0NJoPwrnIMAAAADAQABAAACADQ4KONYQemGT+ssnqKKzxigbIhlVAEeA/yy
omvgZZf0xTrw/jzMnr7umS2RTrLcKCjmLrgKh5HhBug/Y31x5gkeVojNEuXDY6kB97HqtX
+IXqqWGAFzlroMkWZdlFc3YzMgeiu8yrTes1Kcd+EQ6ss7l0NS7P383L/vCxvi8MURQvh6
ez2dZubjmtiSZWgI9DKMEKSeX4SFoaML9AAdjNXbdJNoATWVm0djmgXI+f2liK80nWdpTo
7NjikX4y0+L6SqpigfAiGL4FQ++PgGTTOZ62or6YWh65twLl8ge8iv8bPKxqIsQNrPIHF9
of7VaKMSgTa5fAvsJNQ1lW6exiK1szJ+g+zrkHuOjDaEWyIZi24/xy6iDaT1sdcjTGPJAo
WqgC9hlZQKjOOZJgwqu/kxgcsOGaGb2MD/E4xJVMvPsWYLQ5WGdiakQkVhclpcr3e0d8nw
xvqCqLsasCSECKJK+k3ReqtOe6GlTSzIpFiOgFAuYp+ejRkX6bJ2DRaYkjoWWza2VCpIJC
uyK7B3r1cV+g5KzvT6B+7TxVqYERisjWNvdppF87Vtx7C0p8mDzpJYpPY+yao3vEcq104+
yXuaPGEDTkTWOUB2uUS+AD9CBjkrGYFab1DBJob+L/7jNgVgWmMw1Yj9SDwXO6YBfbkhCf
Irfmf9Ne5i1+2SpFWBAAABAQCud97O9xI2bMGVGfbDFiaPTYGaGZ0qurLtHPpCX/YFkdBh
Z3LG7psJ/4JhkmMI3RFGhMxpUR9K22T3P/UmUt01PrDwDUpcw1JRPVIGs9AV3+GsAyyE6X
MzYo+8LNcxaPjh6ECXAQLcd9g0NOCbiqrKURBEuIBkxTy8jsmmeUlDsLcs8QKCsObJ2ozO
ACuFG5Z/SUeB7nhHnRUnozE8KsEWAgpys37AnJc1cQR6ALloh23L46rsWbSN5UGRgZdaUo
tklsDRun3qtYkDC8dDbW2Iy5A7GUXBRIA3mDYf4GDEUQvuu5Q/A2Dsr0hVi2wNVWd5O5M0
NVhuCHJU355wbbUUAAABAQDuet4GZQImmqfj2xAMoHUfSK0WagtzynP2fOSIRtOKQ9UXJN
J1CrSeu93dNACYjXt10X5ZCdZ9x/75ltyZHSUBbT1eQzPD4Jq23EcJ9ECCc4tJMpdNpJyv
8ixfeTCX0m6XP7nDDLgkuYuNTj/NTqIWotHt8/R8BA9FfTchZE+ekqj3TTIac3buU294mO
/0KKGHtt+GPHSD+ES+W28KETiFcz5nSD7oUQPXEbvsJg5bOWt9kY6JBGiizJSsEuLIjcva
H3UQMx6U805NjoGwIiKJyKgcmDMWVbeH87XxV6sllE8UaLUxbcOBdhmF/uJlazQsbqmF7B
CJB/X7SXredw9BAAABAQDDgRzgXsvBH72PMetQpWGswXp6UVsdHUUEyDiJXc5xjiVOxAIw
+pwaBRQ/6WMMJvhpZ/IFN+pAYEW5e0q2eGMpc1or4kf5eTukwJSF6VZf1Hhti6TfiStPCf
KSz07jUFROahMC88BOSwHuCc66emWlsZDrXS+pht1O7yU96epTM/hT/e8Bfi+ZFCJnQoQ5
dZuONhOYUT32rFKGBwPhsi6pjMB54vqrW1xFJbwj4i4dHFzA7UUa79j7ToAs2g2q8odTCR
CLUxGJ+YOkti67taOuRbzlL9wlxLGT+G2Dai9Ymbt18rmXR+2vazE0xFigYHPZb2QXeLAS
u104cC7ouX7DAAAAFnNzaC50ZXN0LnNtYWxsc3RlcC5jb20BAgME
-----END OPENSSH PRIVATE KEY-----

View file

@ -1 +0,0 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQC2H93vNENu/nrujjopPphioL8RV0OlnybSvMh2ugkv5EdUNc1DWlTeWsL0hZ+TKNf5jA72McZuton4KCicCSXfqa0vM0LjrjlAZLb9Z0WpFHqZALPu/gbqc5oU1IJ4SsyGDR+fh/uN2pel3SFijMimS1FU1hymLZEyavx78pp2vkRVB9thmB0oiXU7paz07dDUGEsaL0aKkSy1RO+OrFnMXuhSBYofniPl4c1d1sWoqqcrw01CbKk0tgYOxMFAcb3NGz6jTsLZ+g7110g7Y4JqtwFZl8aVjTmT9mESasQ442HatIuEnCrD4QmBB5TQ+h78iqMXLuOdNGSL6157Z6UFhW8dg9Kh9loGZfuLbWzx4FGX9mIXHeKnoTczniZe8pntK4Daqr9hbQh4rInkNK4vT6xEDKJJDYhyzJvdgq/oteiO4lOYn0eD2/j5o4/E5CG81+HdvuwiH9ybPV2r9tpsvI45GtscoPEgkwTtU8yymZ11X3AmAHga7V5Y6JH8K6/qOphZN+JYkMiiplEPmcXW0Woq1nmICqR+W/gpujXS2/qdOVdHpNa4+7Zn+lmnH4037tDFVRKVGDFvtFNVwdG+mB3RM0RcZ25HxqtZUOU17A6Z/koWuhbUhf/JHU68TlGwOLHFxfs1765KWr2GijL1X8inR99wX8qDQ0mg/Cucgw== ssh.test.smallstep.com

View file

@ -1 +0,0 @@
trim-this-pin

View file

@ -1,148 +0,0 @@
package uri
import (
"bytes"
"encoding/hex"
"net/url"
"os"
"strings"
"unicode"
"github.com/pkg/errors"
)
// URI implements a parser for a URI format based on the the PKCS #11 URI Scheme
// defined in https://tools.ietf.org/html/rfc7512
//
// These URIs will be used to define the key names in a KMS.
type URI struct {
*url.URL
Values url.Values
}
// New creates a new URI from a scheme and key-value pairs.
func New(scheme string, values url.Values) *URI {
return &URI{
URL: &url.URL{
Scheme: scheme,
Opaque: strings.ReplaceAll(values.Encode(), "&", ";"),
},
Values: values,
}
}
// NewFile creates an uri for a file.
func NewFile(path string) *URI {
return &URI{
URL: &url.URL{
Scheme: "file",
Path: path,
},
}
}
// HasScheme returns true if the given uri has the given scheme, false otherwise.
func HasScheme(scheme, rawuri string) bool {
u, err := url.Parse(rawuri)
if err != nil {
return false
}
return strings.EqualFold(u.Scheme, scheme)
}
// Parse returns the URI for the given string or an error.
func Parse(rawuri string) (*URI, error) {
u, err := url.Parse(rawuri)
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", rawuri)
}
if u.Scheme == "" {
return nil, errors.Errorf("error parsing %s: scheme is missing", rawuri)
}
// Starting with Go 1.17 url.ParseQuery returns an error using semicolon as
// separator.
v, err := url.ParseQuery(strings.ReplaceAll(u.Opaque, ";", "&"))
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", rawuri)
}
return &URI{
URL: u,
Values: v,
}, nil
}
// ParseWithScheme returns the URI for the given string only if it has the given
// scheme.
func ParseWithScheme(scheme, rawuri string) (*URI, error) {
u, err := Parse(rawuri)
if err != nil {
return nil, err
}
if !strings.EqualFold(u.Scheme, scheme) {
return nil, errors.Errorf("error parsing %s: scheme not expected", rawuri)
}
return u, nil
}
// Get returns the first value in the uri with the given key, it will return
// empty string if that field is not present.
func (u *URI) Get(key string) string {
v := u.Values.Get(key)
if v == "" {
v = u.URL.Query().Get(key)
}
return v
}
// GetBool returns true if a given key has the value "true". It returns false
// otherwise.
func (u *URI) GetBool(key string) bool {
v := u.Values.Get(key)
if v == "" {
v = u.URL.Query().Get(key)
}
return strings.EqualFold(v, "true")
}
// GetEncoded returns the first value in the uri with the given key, it will
// return empty nil if that field is not present or is empty. If the return
// value is hex encoded it will decode it and return it.
func (u *URI) GetEncoded(key string) []byte {
v := u.Get(key)
if v == "" {
return nil
}
if len(v)%2 == 0 {
if b, err := hex.DecodeString(v); err == nil {
return b
}
}
return []byte(v)
}
// Pin returns the pin encoded in the url. It will read the pin from the
// pin-value or the pin-source attributes.
func (u *URI) Pin() string {
if value := u.Get("pin-value"); value != "" {
return value
}
if path := u.Get("pin-source"); path != "" {
if b, err := readFile(path); err == nil {
return string(bytes.TrimRightFunc(b, unicode.IsSpace))
}
}
return ""
}
func readFile(path string) ([]byte, error) {
u, err := url.Parse(path)
if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" {
path = u.Path
}
b, err := os.ReadFile(path)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", path)
}
return b, nil
}

View file

@ -1,62 +0,0 @@
//go:build go1.19
package uri
import (
"net/url"
"reflect"
"testing"
)
func TestParse(t *testing.T) {
type args struct {
rawuri string
}
tests := []struct {
name string
args args
want *URI
wantErr bool
}{
{"ok", args{"yubikey:slot-id=9a"}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
Values: url.Values{"slot-id": []string{"9a"}},
}, false},
{"ok schema", args{"cloudkms:"}, &URI{
URL: &url.URL{Scheme: "cloudkms"},
Values: url.Values{},
}, false},
{"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"},
Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}},
}, false},
{"ok file", args{"file:///tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
Values: url.Values{},
}, false},
{"ok file simple", args{"file:/tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert", OmitHost: true},
Values: url.Values{},
}, false},
{"ok file host", args{"file://tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"},
Values: url.Values{},
}, false},
{"fail schema", args{"cloudkms"}, nil, true},
{"fail parse", args{"yubi%key:slot-id=9a"}, nil, true},
{"fail scheme", args{"yubikey"}, nil, true},
{"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Parse(tt.args.rawuri)
if (err != nil) != tt.wantErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL)
}
})
}
}

View file

@ -1,62 +0,0 @@
//go:build !go1.19
package uri
import (
"net/url"
"reflect"
"testing"
)
func TestParse(t *testing.T) {
type args struct {
rawuri string
}
tests := []struct {
name string
args args
want *URI
wantErr bool
}{
{"ok", args{"yubikey:slot-id=9a"}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
Values: url.Values{"slot-id": []string{"9a"}},
}, false},
{"ok schema", args{"cloudkms:"}, &URI{
URL: &url.URL{Scheme: "cloudkms"},
Values: url.Values{},
}, false},
{"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"},
Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}},
}, false},
{"ok file", args{"file:///tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
Values: url.Values{},
}, false},
{"ok file simple", args{"file:/tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
Values: url.Values{},
}, false},
{"ok file host", args{"file://tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"},
Values: url.Values{},
}, false},
{"fail schema", args{"cloudkms"}, nil, true},
{"fail parse", args{"yubi%key:slot-id=9a"}, nil, true},
{"fail scheme", args{"yubikey"}, nil, true},
{"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Parse(tt.args.rawuri)
if (err != nil) != tt.wantErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL)
}
})
}
}

View file

@ -1,282 +0,0 @@
package uri
import (
"net/url"
"reflect"
"testing"
)
func TestNew(t *testing.T) {
type args struct {
scheme string
values url.Values
}
tests := []struct {
name string
args args
want *URI
}{
{"ok", args{"yubikey", url.Values{"slot-id": []string{"9a"}}}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
Values: url.Values{"slot-id": []string{"9a"}},
}},
{"ok multiple", args{"yubikey", url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "foo=bar;slot-id=9a"},
Values: url.Values{
"slot-id": []string{"9a"},
"foo": []string{"bar"},
},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := New(tt.args.scheme, tt.args.values); !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewFile(t *testing.T) {
type args struct {
path string
}
tests := []struct {
name string
args args
want *URI
}{
{"ok", args{"/tmp/ca.crt"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.crt"},
Values: url.Values(nil),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewFile(tt.args.path); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewFile() = %v, want %v", got, tt.want)
}
})
}
}
func TestHasScheme(t *testing.T) {
type args struct {
scheme string
rawuri string
}
tests := []struct {
name string
args args
want bool
}{
{"ok", args{"yubikey", "yubikey:slot-id=9a"}, true},
{"ok empty", args{"yubikey", "yubikey:"}, true},
{"ok letter case", args{"awsKMS", "AWSkms:key-id=abcdefg?foo=bar"}, true},
{"fail", args{"yubikey", "awskms:key-id=abcdefg"}, false},
{"fail parse", args{"yubikey", "yubi%key:slot-id=9a"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HasScheme(tt.args.scheme, tt.args.rawuri); got != tt.want {
t.Errorf("HasScheme() = %v, want %v", got, tt.want)
}
})
}
}
func TestParseWithScheme(t *testing.T) {
type args struct {
scheme string
rawuri string
}
tests := []struct {
name string
args args
want *URI
wantErr bool
}{
{"ok", args{"yubikey", "yubikey:slot-id=9a"}, &URI{
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
Values: url.Values{"slot-id": []string{"9a"}},
}, false},
{"ok schema", args{"cloudkms", "cloudkms:"}, &URI{
URL: &url.URL{Scheme: "cloudkms"},
Values: url.Values{},
}, false},
{"ok file", args{"file", "file:///tmp/ca.cert"}, &URI{
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
Values: url.Values{},
}, false},
{"fail parse", args{"yubikey", "yubikey"}, nil, true},
{"fail scheme", args{"yubikey", "awskms:slot-id=9a"}, nil, true},
{"fail schema", args{"cloudkms", "cloudkms"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseWithScheme(tt.args.scheme, tt.args.rawuri)
if (err != nil) != tt.wantErr {
t.Errorf("ParseWithScheme() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseWithScheme() = %v, want %v", got, tt.want)
}
})
}
}
func TestURI_Get(t *testing.T) {
mustParse := func(s string) *URI {
u, err := Parse(s)
if err != nil {
t.Fatal(err)
}
return u
}
type args struct {
key string
}
tests := []struct {
name string
uri *URI
args args
want string
}{
{"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, "9a"},
{"ok first", mustParse("yubikey:slot-id=9a;slot-id=9b"), args{"slot-id"}, "9a"},
{"ok multiple", mustParse("yubikey:slot-id=9a;foo=bar"), args{"foo"}, "bar"},
{"ok in query", mustParse("yubikey:slot-id=9a?foo=bar"), args{"foo"}, "bar"},
{"fail missing", mustParse("yubikey:slot-id=9a"), args{"foo"}, ""},
{"fail missing query", mustParse("yubikey:slot-id=9a?bar=zar"), args{"foo"}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.uri.Get(tt.args.key); got != tt.want {
t.Errorf("URI.Get() = %v, want %v", got, tt.want)
}
})
}
}
func TestURI_GetBool(t *testing.T) {
mustParse := func(s string) *URI {
u, err := Parse(s)
if err != nil {
t.Fatal(err)
}
return u
}
type args struct {
key string
}
tests := []struct {
name string
uri *URI
args args
want bool
}{
{"true", mustParse("azurekms:name=foo;vault=bar;hsm=true"), args{"hsm"}, true},
{"TRUE", mustParse("azurekms:name=foo;vault=bar;hsm=TRUE"), args{"hsm"}, true},
{"tRUe query", mustParse("azurekms:name=foo;vault=bar?hsm=tRUe"), args{"hsm"}, true},
{"false", mustParse("azurekms:name=foo;vault=bar;hsm=false"), args{"hsm"}, false},
{"false query", mustParse("azurekms:name=foo;vault=bar?hsm=false"), args{"hsm"}, false},
{"empty", mustParse("azurekms:name=foo;vault=bar;hsm=?bar=true"), args{"hsm"}, false},
{"missing", mustParse("azurekms:name=foo;vault=bar"), args{"hsm"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.uri.GetBool(tt.args.key); got != tt.want {
t.Errorf("URI.GetBool() = %v, want %v", got, tt.want)
}
})
}
}
func TestURI_GetEncoded(t *testing.T) {
mustParse := func(s string) *URI {
u, err := Parse(s)
if err != nil {
t.Fatal(err)
}
return u
}
type args struct {
key string
}
tests := []struct {
name string
uri *URI
args args
want []byte
}{
{"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, []byte{0x9a}},
{"ok first", mustParse("yubikey:slot-id=9a9b;slot-id=9b"), args{"slot-id"}, []byte{0x9a, 0x9b}},
{"ok percent", mustParse("yubikey:slot-id=9a;foo=%9a%9b%9c"), args{"foo"}, []byte{0x9a, 0x9b, 0x9c}},
{"ok in query", mustParse("yubikey:slot-id=9a?foo=9a"), args{"foo"}, []byte{0x9a}},
{"ok in query percent", mustParse("yubikey:slot-id=9a?foo=%9a"), args{"foo"}, []byte{0x9a}},
{"ok missing", mustParse("yubikey:slot-id=9a"), args{"foo"}, nil},
{"ok missing query", mustParse("yubikey:slot-id=9a?bar=zar"), args{"foo"}, nil},
{"ok no hex", mustParse("yubikey:slot-id=09a?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'a'}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.uri.GetEncoded(tt.args.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("URI.GetEncoded() = %v, want %v", got, tt.want)
}
})
}
}
func TestURI_Pin(t *testing.T) {
mustParse := func(s string) *URI {
u, err := Parse(s)
if err != nil {
t.Fatal(err)
}
return u
}
tests := []struct {
name string
uri *URI
want string
}{
{"from value", mustParse("pkcs11:id=%72%73?pin-value=0123456789"), "0123456789"},
{"from source", mustParse("pkcs11:id=%72%73?pin-source=testdata/pin.txt"), "trim-this-pin"},
{"from missing", mustParse("pkcs11:id=%72%73"), ""},
{"from source missing", mustParse("pkcs11:id=%72%73?pin-source=testdata/foo.txt"), ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.uri.Pin(); got != tt.want {
t.Errorf("URI.Pin() = %v, want %v", got, tt.want)
}
})
}
}
func TestURI_String(t *testing.T) {
mustParse := func(s string) *URI {
u, err := Parse(s)
if err != nil {
t.Fatal(err)
}
return u
}
tests := []struct {
name string
uri *URI
want string
}{
{"ok new", New("yubikey", url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}), "yubikey:foo=bar;slot-id=9a"},
{"ok parse", mustParse("yubikey:slot-id=9a;foo=bar?bar=zar"), "yubikey:slot-id=9a;foo=bar?bar=zar"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.uri.String(); got != tt.want {
t.Errorf("URI.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,322 +0,0 @@
//go:build cgo
// +build cgo
package yubikey
import (
"context"
"crypto"
"crypto/x509"
"encoding/hex"
"net/url"
"strings"
"github.com/go-piv/piv-go/piv"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/uri"
)
// Scheme is the scheme used in uris.
const Scheme = "yubikey"
// YubiKey implements the KMS interface on a YubiKey.
type YubiKey struct {
yk *piv.YubiKey
pin string
managementKey [24]byte
}
// New initializes a new YubiKey.
// TODO(mariano): only one card is currently supported.
func New(ctx context.Context, opts apiv1.Options) (*YubiKey, error) {
managementKey := piv.DefaultManagementKey
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return nil, err
}
if v := u.Pin(); v != "" {
opts.Pin = v
}
if v := u.Get("management-key"); v != "" {
opts.ManagementKey = v
}
}
// Deprecated way to set configuration parameters.
if opts.ManagementKey != "" {
b, err := hex.DecodeString(opts.ManagementKey)
if err != nil {
return nil, errors.Wrap(err, "error decoding managementKey")
}
if len(b) != 24 {
return nil, errors.New("invalid managementKey: length is not 24 bytes")
}
copy(managementKey[:], b[:24])
}
cards, err := piv.Cards()
if err != nil {
return nil, err
}
if len(cards) == 0 {
return nil, errors.New("error detecting yubikey: try removing and reconnecting the device")
}
yk, err := piv.Open(cards[0])
if err != nil {
return nil, errors.Wrap(err, "error opening yubikey")
}
return &YubiKey{
yk: yk,
pin: opts.Pin,
managementKey: managementKey,
}, nil
}
func init() {
apiv1.Register(apiv1.YubiKey, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
return New(ctx, opts)
})
}
// LoadCertificate implements kms.CertificateManager and loads a certificate
// from the YubiKey.
func (k *YubiKey) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {
slot, err := getSlot(req.Name)
if err != nil {
return nil, err
}
cert, err := k.yk.Certificate(slot)
if err != nil {
return nil, errors.Wrap(err, "error retrieving certificate")
}
return cert, nil
}
// StoreCertificate implements kms.CertificateManager and stores a certificate
// in the YubiKey.
func (k *YubiKey) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
if req.Certificate == nil {
return errors.New("storeCertificateRequest 'Certificate' cannot be nil")
}
slot, err := getSlot(req.Name)
if err != nil {
return err
}
err = k.yk.SetCertificate(k.managementKey, slot, req.Certificate)
if err != nil {
return errors.Wrap(err, "error storing certificate")
}
return nil
}
// GetPublicKey returns the public key present in the YubiKey signature slot.
func (k *YubiKey) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
slot, err := getSlot(req.Name)
if err != nil {
return nil, err
}
pub, err := k.getPublicKey(slot)
if err != nil {
return nil, err
}
return pub, nil
}
// CreateKey generates a new key in the YubiKey and returns the public key.
func (k *YubiKey) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
alg, err := getSignatureAlgorithm(req.SignatureAlgorithm, req.Bits)
if err != nil {
return nil, err
}
slot, name, err := getSlotAndName(req.Name)
if err != nil {
return nil, err
}
pub, err := k.yk.GenerateKey(k.managementKey, slot, piv.Key{
Algorithm: alg,
PINPolicy: piv.PINPolicyAlways,
TouchPolicy: piv.TouchPolicyNever,
})
if err != nil {
return nil, errors.Wrap(err, "error generating key")
}
return &apiv1.CreateKeyResponse{
Name: name,
PublicKey: pub,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: name,
},
}, nil
}
// CreateSigner creates a signer using the key present in the YubiKey signature
// slot.
func (k *YubiKey) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
slot, err := getSlot(req.SigningKey)
if err != nil {
return nil, err
}
pub, err := k.getPublicKey(slot)
if err != nil {
return nil, err
}
priv, err := k.yk.PrivateKey(slot, pub, piv.KeyAuth{
PIN: k.pin,
PINPolicy: piv.PINPolicyAlways,
})
if err != nil {
return nil, errors.Wrap(err, "error retrieving private key")
}
signer, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("private key is not a crypto.Signer")
}
return signer, nil
}
// Close releases the connection to the YubiKey.
func (k *YubiKey) Close() error {
return errors.Wrap(k.yk.Close(), "error closing yubikey")
}
// getPublicKey returns the public key on a slot. First it attempts to do
// attestation to get a certificate with the public key in it, if this succeeds
// means that the key was generated in the device. If not we'll try to get the
// key from a stored certificate in the same slot.
func (k *YubiKey) getPublicKey(slot piv.Slot) (crypto.PublicKey, error) {
cert, err := k.yk.Attest(slot)
if err != nil {
if cert, err = k.yk.Certificate(slot); err != nil {
return nil, errors.Wrap(err, "error retrieving public key")
}
}
return cert.PublicKey, nil
}
// signatureAlgorithmMapping is a mapping between the step signature algorithm,
// and bits for RSA keys, with yubikey ones.
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
apiv1.UnspecifiedSignAlgorithm: piv.AlgorithmEC256,
apiv1.SHA256WithRSA: map[int]piv.Algorithm{
0: piv.AlgorithmRSA2048,
1024: piv.AlgorithmRSA1024,
2048: piv.AlgorithmRSA2048,
},
apiv1.SHA512WithRSA: map[int]piv.Algorithm{
0: piv.AlgorithmRSA2048,
1024: piv.AlgorithmRSA1024,
2048: piv.AlgorithmRSA2048,
},
apiv1.SHA256WithRSAPSS: map[int]piv.Algorithm{
0: piv.AlgorithmRSA2048,
1024: piv.AlgorithmRSA1024,
2048: piv.AlgorithmRSA2048,
},
apiv1.SHA512WithRSAPSS: map[int]piv.Algorithm{
0: piv.AlgorithmRSA2048,
1024: piv.AlgorithmRSA1024,
2048: piv.AlgorithmRSA2048,
},
apiv1.ECDSAWithSHA256: piv.AlgorithmEC256,
apiv1.ECDSAWithSHA384: piv.AlgorithmEC384,
}
func getSignatureAlgorithm(alg apiv1.SignatureAlgorithm, bits int) (piv.Algorithm, error) {
v, ok := signatureAlgorithmMapping[alg]
if !ok {
return 0, errors.Errorf("YubiKey does not support signature algorithm '%s'", alg)
}
switch v := v.(type) {
case piv.Algorithm:
return v, nil
case map[int]piv.Algorithm:
signatureAlgorithm, ok := v[bits]
if !ok {
return 0, errors.Errorf("YubiKey does not support signature algorithm '%s' with '%d' bits", alg, bits)
}
return signatureAlgorithm, nil
default:
return 0, errors.Errorf("unexpected error: this should not happen")
}
}
var slotMapping = map[string]piv.Slot{
"9a": piv.SlotAuthentication,
"9c": piv.SlotSignature,
"9e": piv.SlotCardAuthentication,
"9d": piv.SlotKeyManagement,
"82": {Key: 0x82, Object: 0x5FC10D},
"83": {Key: 0x83, Object: 0x5FC10E},
"84": {Key: 0x84, Object: 0x5FC10F},
"85": {Key: 0x85, Object: 0x5FC110},
"86": {Key: 0x86, Object: 0x5FC111},
"87": {Key: 0x87, Object: 0x5FC112},
"88": {Key: 0x88, Object: 0x5FC113},
"89": {Key: 0x89, Object: 0x5FC114},
"8a": {Key: 0x8a, Object: 0x5FC115},
"8b": {Key: 0x8b, Object: 0x5FC116},
"8c": {Key: 0x8c, Object: 0x5FC117},
"8d": {Key: 0x8d, Object: 0x5FC118},
"8e": {Key: 0x8e, Object: 0x5FC119},
"8f": {Key: 0x8f, Object: 0x5FC11A},
"90": {Key: 0x90, Object: 0x5FC11B},
"91": {Key: 0x91, Object: 0x5FC11C},
"92": {Key: 0x92, Object: 0x5FC11D},
"93": {Key: 0x93, Object: 0x5FC11E},
"94": {Key: 0x94, Object: 0x5FC11F},
"95": {Key: 0x95, Object: 0x5FC120},
}
func getSlot(name string) (piv.Slot, error) {
slot, _, err := getSlotAndName(name)
return slot, err
}
func getSlotAndName(name string) (piv.Slot, string, error) {
if name == "" {
return piv.SlotSignature, "yubikey:slot-id=9c", nil
}
var slotID string
name = strings.ToLower(name)
if strings.HasPrefix(name, "yubikey:") {
u, err := url.Parse(name)
if err != nil {
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s'", name)
}
v, err := url.ParseQuery(u.Opaque)
if err != nil {
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s'", name)
}
if slotID = v.Get("slot-id"); slotID == "" {
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s': slot-id is missing", name)
}
} else {
slotID = name
}
s, ok := slotMapping[slotID]
if !ok {
return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
}
name = "yubikey:slot-id=" + url.QueryEscape(slotID)
return s, name, nil
}

View file

@ -1,20 +0,0 @@
//go:build !cgo
// +build !cgo
package yubikey
import (
"context"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
)
func init() {
apiv1.Register(apiv1.YubiKey, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
name := filepath.Base(os.Args[0])
return nil, errors.Errorf("unsupported kms type 'yubikey': %s is compiled without cgo support", name)
})
}