forked from TrueCloudLab/certificates
Remove kms package
This commit is contained in:
parent
369b8f81c3
commit
4985ab1d62
57 changed files with 0 additions and 9040 deletions
|
@ -1,137 +0,0 @@
|
||||||
package apiv1
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/x509"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeyManager is the interface implemented by all the KMS.
|
|
||||||
type KeyManager interface {
|
|
||||||
GetPublicKey(req *GetPublicKeyRequest) (crypto.PublicKey, error)
|
|
||||||
CreateKey(req *CreateKeyRequest) (*CreateKeyResponse, error)
|
|
||||||
CreateSigner(req *CreateSignerRequest) (crypto.Signer, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decrypter is an interface implemented by KMSes that are used
|
|
||||||
// in operations that require decryption
|
|
||||||
type Decrypter interface {
|
|
||||||
CreateDecrypter(req *CreateDecrypterRequest) (crypto.Decrypter, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CertificateManager is the interface implemented by the KMS that can load and
|
|
||||||
// store x509.Certificates.
|
|
||||||
type CertificateManager interface {
|
|
||||||
LoadCertificate(req *LoadCertificateRequest) (*x509.Certificate, error)
|
|
||||||
StoreCertificate(req *StoreCertificateRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateName is an interface that KeyManager can implement to validate a
|
|
||||||
// given name or URI.
|
|
||||||
type NameValidator interface {
|
|
||||||
ValidateName(s string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNotImplemented is the type of error returned if an operation is not
|
|
||||||
// implemented.
|
|
||||||
type ErrNotImplemented struct {
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e ErrNotImplemented) Error() string {
|
|
||||||
if e.Message != "" {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
return "not implemented"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrAlreadyExists is the type of error returned if a key already exists. This
|
|
||||||
// is currently only implmented on pkcs11.
|
|
||||||
type ErrAlreadyExists struct {
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e ErrAlreadyExists) Error() string {
|
|
||||||
if e.Message != "" {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
return "key already exists"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type represents the KMS type used.
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultKMS is a KMS implementation using software.
|
|
||||||
DefaultKMS Type = ""
|
|
||||||
// SoftKMS is a KMS implementation using software.
|
|
||||||
SoftKMS Type = "softkms"
|
|
||||||
// CloudKMS is a KMS implementation using Google's Cloud KMS.
|
|
||||||
CloudKMS Type = "cloudkms"
|
|
||||||
// AmazonKMS is a KMS implementation using Amazon AWS KMS.
|
|
||||||
AmazonKMS Type = "awskms"
|
|
||||||
// PKCS11 is a KMS implementation using the PKCS11 standard.
|
|
||||||
PKCS11 Type = "pkcs11"
|
|
||||||
// YubiKey is a KMS implementation using a YubiKey PIV.
|
|
||||||
YubiKey Type = "yubikey"
|
|
||||||
// SSHAgentKMS is a KMS implementation using ssh-agent to access keys.
|
|
||||||
SSHAgentKMS Type = "sshagentkms"
|
|
||||||
// AzureKMS is a KMS implementation using Azure Key Vault.
|
|
||||||
AzureKMS Type = "azurekms"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Options are the KMS options. They represent the kms object in the ca.json.
|
|
||||||
type Options struct {
|
|
||||||
// The type of the KMS to use.
|
|
||||||
Type string `json:"type"`
|
|
||||||
|
|
||||||
// Path to the credentials file used in CloudKMS and AmazonKMS.
|
|
||||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
|
||||||
|
|
||||||
// URI is based on the PKCS #11 URI Scheme defined in
|
|
||||||
// https://tools.ietf.org/html/rfc7512 and represents the configuration used
|
|
||||||
// to connect to the KMS.
|
|
||||||
//
|
|
||||||
// Used by: pkcs11
|
|
||||||
URI string `json:"uri,omitempty"`
|
|
||||||
|
|
||||||
// Pin used to access the PKCS11 module. It can be defined in the URI using
|
|
||||||
// the pin-value or pin-source properties.
|
|
||||||
Pin string `json:"pin,omitempty"`
|
|
||||||
|
|
||||||
// ManagementKey used in YubiKeys. Default management key is the hexadecimal
|
|
||||||
// string 010203040506070801020304050607080102030405060708:
|
|
||||||
// []byte{
|
|
||||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
|
||||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
|
||||||
// 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
|
||||||
// }
|
|
||||||
ManagementKey string `json:"managementKey,omitempty"`
|
|
||||||
|
|
||||||
// Region to use in AmazonKMS.
|
|
||||||
Region string `json:"region,omitempty"`
|
|
||||||
|
|
||||||
// Profile to use in AmazonKMS.
|
|
||||||
Profile string `json:"profile,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate checks the fields in Options.
|
|
||||||
func (o *Options) Validate() error {
|
|
||||||
if o == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch Type(strings.ToLower(o.Type)) {
|
|
||||||
case DefaultKMS, SoftKMS: // Go crypto based kms.
|
|
||||||
case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms.
|
|
||||||
case YubiKey, PKCS11: // Hardware based kms.
|
|
||||||
case SSHAgentKMS: // Others
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unsupported kms type %s", o.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,76 +0,0 @@
|
||||||
package apiv1
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOptions_Validate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
options *Options
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"nil", nil, false},
|
|
||||||
{"softkms", &Options{Type: "softkms"}, false},
|
|
||||||
{"cloudkms", &Options{Type: "cloudkms"}, false},
|
|
||||||
{"awskms", &Options{Type: "awskms"}, false},
|
|
||||||
{"sshagentkms", &Options{Type: "sshagentkms"}, false},
|
|
||||||
{"pkcs11", &Options{Type: "pkcs11"}, false},
|
|
||||||
{"unsupported", &Options{Type: "unsupported"}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := tt.options.Validate(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrNotImplemented_Error(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
msg string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"default", fields{}, "not implemented"},
|
|
||||||
{"custom", fields{"custom message: not implemented"}, "custom message: not implemented"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
e := ErrNotImplemented{
|
|
||||||
Message: tt.fields.msg,
|
|
||||||
}
|
|
||||||
if got := e.Error(); got != tt.want {
|
|
||||||
t.Errorf("ErrNotImplemented.Error() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrAlreadyExists_Error(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
msg string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"default", fields{}, "key already exists"},
|
|
||||||
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
e := ErrAlreadyExists{
|
|
||||||
Message: tt.fields.msg,
|
|
||||||
}
|
|
||||||
if got := e.Error(); got != tt.want {
|
|
||||||
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,27 +0,0 @@
|
||||||
package apiv1
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var registry = new(sync.Map)
|
|
||||||
|
|
||||||
// KeyManagerNewFunc is the type that represents the method to initialize a new
|
|
||||||
// KeyManager.
|
|
||||||
type KeyManagerNewFunc func(ctx context.Context, opts Options) (KeyManager, error)
|
|
||||||
|
|
||||||
// Register adds to the registry a method to create a KeyManager of type t.
|
|
||||||
func Register(t Type, fn KeyManagerNewFunc) {
|
|
||||||
registry.Store(t, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadKeyManagerNewFunc returns the function initialize a KayManager.
|
|
||||||
func LoadKeyManagerNewFunc(t Type) (KeyManagerNewFunc, bool) {
|
|
||||||
v, ok := registry.Load(t)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
fn, ok := v.(KeyManagerNewFunc)
|
|
||||||
return fn, ok
|
|
||||||
}
|
|
|
@ -1,167 +0,0 @@
|
||||||
package apiv1
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ProtectionLevel specifies on some KMS how cryptographic operations are
|
|
||||||
// performed.
|
|
||||||
type ProtectionLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Protection level not specified.
|
|
||||||
UnspecifiedProtectionLevel ProtectionLevel = iota
|
|
||||||
// Crypto operations are performed in software.
|
|
||||||
Software
|
|
||||||
// Crypto operations are performed in a Hardware Security Module.
|
|
||||||
HSM
|
|
||||||
)
|
|
||||||
|
|
||||||
// String returns a string representation of p.
|
|
||||||
func (p ProtectionLevel) String() string {
|
|
||||||
switch p {
|
|
||||||
case UnspecifiedProtectionLevel:
|
|
||||||
return "unspecified"
|
|
||||||
case Software:
|
|
||||||
return "software"
|
|
||||||
case HSM:
|
|
||||||
return "hsm"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("unknown(%d)", p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignatureAlgorithm used for cryptographic signing.
|
|
||||||
type SignatureAlgorithm int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Not specified.
|
|
||||||
UnspecifiedSignAlgorithm SignatureAlgorithm = iota
|
|
||||||
// RSASSA-PKCS1-v1_5 key and a SHA256 digest.
|
|
||||||
SHA256WithRSA
|
|
||||||
// RSASSA-PKCS1-v1_5 key and a SHA384 digest.
|
|
||||||
SHA384WithRSA
|
|
||||||
// RSASSA-PKCS1-v1_5 key and a SHA512 digest.
|
|
||||||
SHA512WithRSA
|
|
||||||
// RSASSA-PSS key with a SHA256 digest.
|
|
||||||
SHA256WithRSAPSS
|
|
||||||
// RSASSA-PSS key with a SHA384 digest.
|
|
||||||
SHA384WithRSAPSS
|
|
||||||
// RSASSA-PSS key with a SHA512 digest.
|
|
||||||
SHA512WithRSAPSS
|
|
||||||
// ECDSA on the NIST P-256 curve with a SHA256 digest.
|
|
||||||
ECDSAWithSHA256
|
|
||||||
// ECDSA on the NIST P-384 curve with a SHA384 digest.
|
|
||||||
ECDSAWithSHA384
|
|
||||||
// ECDSA on the NIST P-521 curve with a SHA512 digest.
|
|
||||||
ECDSAWithSHA512
|
|
||||||
// EdDSA on Curve25519 with a SHA512 digest.
|
|
||||||
PureEd25519
|
|
||||||
)
|
|
||||||
|
|
||||||
// String returns a string representation of s.
|
|
||||||
func (s SignatureAlgorithm) String() string {
|
|
||||||
switch s {
|
|
||||||
case UnspecifiedSignAlgorithm:
|
|
||||||
return "unspecified"
|
|
||||||
case SHA256WithRSA:
|
|
||||||
return "SHA256-RSA"
|
|
||||||
case SHA384WithRSA:
|
|
||||||
return "SHA384-RSA"
|
|
||||||
case SHA512WithRSA:
|
|
||||||
return "SHA512-RSA"
|
|
||||||
case SHA256WithRSAPSS:
|
|
||||||
return "SHA256-RSAPSS"
|
|
||||||
case SHA384WithRSAPSS:
|
|
||||||
return "SHA384-RSAPSS"
|
|
||||||
case SHA512WithRSAPSS:
|
|
||||||
return "SHA512-RSAPSS"
|
|
||||||
case ECDSAWithSHA256:
|
|
||||||
return "ECDSA-SHA256"
|
|
||||||
case ECDSAWithSHA384:
|
|
||||||
return "ECDSA-SHA384"
|
|
||||||
case ECDSAWithSHA512:
|
|
||||||
return "ECDSA-SHA512"
|
|
||||||
case PureEd25519:
|
|
||||||
return "Ed25519"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("unknown(%d)", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKeyRequest is the parameter used in the kms.GetPublicKey method.
|
|
||||||
type GetPublicKeyRequest struct {
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKeyRequest is the parameter used in the kms.CreateKey method.
|
|
||||||
type CreateKeyRequest struct {
|
|
||||||
// Name represents the key name or label used to identify a key.
|
|
||||||
//
|
|
||||||
// Used by: awskms, cloudkms, azurekms, pkcs11, yubikey.
|
|
||||||
Name string
|
|
||||||
|
|
||||||
// SignatureAlgorithm represents the type of key to create.
|
|
||||||
SignatureAlgorithm SignatureAlgorithm
|
|
||||||
|
|
||||||
// Bits is the number of bits on RSA keys.
|
|
||||||
Bits int
|
|
||||||
|
|
||||||
// ProtectionLevel specifies how cryptographic operations are performed.
|
|
||||||
// Used by: cloudkms, azurekms.
|
|
||||||
ProtectionLevel ProtectionLevel
|
|
||||||
|
|
||||||
// Extractable defines if the new key may be exported from the HSM under a
|
|
||||||
// wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
|
|
||||||
//
|
|
||||||
// Used by: pkcs11
|
|
||||||
Extractable bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKeyResponse is the response value of the kms.CreateKey method.
|
|
||||||
type CreateKeyResponse struct {
|
|
||||||
Name string
|
|
||||||
PublicKey crypto.PublicKey
|
|
||||||
PrivateKey crypto.PrivateKey
|
|
||||||
CreateSignerRequest CreateSignerRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSignerRequest is the parameter used in the kms.CreateSigner method.
|
|
||||||
type CreateSignerRequest struct {
|
|
||||||
Signer crypto.Signer
|
|
||||||
SigningKey string
|
|
||||||
SigningKeyPEM []byte
|
|
||||||
TokenLabel string
|
|
||||||
PublicKey string
|
|
||||||
PublicKeyPEM []byte
|
|
||||||
Password []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDecrypterRequest is the parameter used in the kms.Decrypt method.
|
|
||||||
type CreateDecrypterRequest struct {
|
|
||||||
Decrypter crypto.Decrypter
|
|
||||||
DecryptionKey string
|
|
||||||
DecryptionKeyPEM []byte
|
|
||||||
Password []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadCertificateRequest is the parameter used in the LoadCertificate method of
|
|
||||||
// a CertificateManager.
|
|
||||||
type LoadCertificateRequest struct {
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreCertificateRequest is the parameter used in the StoreCertificate method
|
|
||||||
// of a CertificateManager.
|
|
||||||
type StoreCertificateRequest struct {
|
|
||||||
Name string
|
|
||||||
Certificate *x509.Certificate
|
|
||||||
|
|
||||||
// Extractable defines if the new certificate may be exported from the HSM
|
|
||||||
// under a wrap key. On pkcs11 sets the CKA_EXTRACTABLE bit.
|
|
||||||
//
|
|
||||||
// Used by: pkcs11
|
|
||||||
Extractable bool
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
package apiv1
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestProtectionLevel_String(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
p ProtectionLevel
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"unspecified", UnspecifiedProtectionLevel, "unspecified"},
|
|
||||||
{"software", Software, "software"},
|
|
||||||
{"hsm", HSM, "hsm"},
|
|
||||||
{"unknown", ProtectionLevel(100), "unknown(100)"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.p.String(); got != tt.want {
|
|
||||||
t.Errorf("ProtectionLevel.String() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignatureAlgorithm_String(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
s SignatureAlgorithm
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"UnspecifiedSignAlgorithm", UnspecifiedSignAlgorithm, "unspecified"},
|
|
||||||
{"SHA256WithRSA", SHA256WithRSA, "SHA256-RSA"},
|
|
||||||
{"SHA384WithRSA", SHA384WithRSA, "SHA384-RSA"},
|
|
||||||
{"SHA512WithRSA", SHA512WithRSA, "SHA512-RSA"},
|
|
||||||
{"SHA256WithRSAPSS", SHA256WithRSAPSS, "SHA256-RSAPSS"},
|
|
||||||
{"SHA384WithRSAPSS", SHA384WithRSAPSS, "SHA384-RSAPSS"},
|
|
||||||
{"SHA512WithRSAPSS", SHA512WithRSAPSS, "SHA512-RSAPSS"},
|
|
||||||
{"ECDSAWithSHA256", ECDSAWithSHA256, "ECDSA-SHA256"},
|
|
||||||
{"ECDSAWithSHA384", ECDSAWithSHA384, "ECDSA-SHA384"},
|
|
||||||
{"ECDSAWithSHA512", ECDSAWithSHA512, "ECDSA-SHA512"},
|
|
||||||
{"PureEd25519", PureEd25519, "Ed25519"},
|
|
||||||
{"unknown", SignatureAlgorithm(100), "unknown(100)"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.s.String(); got != tt.want {
|
|
||||||
t.Errorf("SignatureAlgorithm.String() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,267 +0,0 @@
|
||||||
package awskms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kms"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Scheme is the scheme used in uris.
|
|
||||||
const Scheme = "awskms"
|
|
||||||
|
|
||||||
// KMS implements a KMS using AWS Key Management Service.
|
|
||||||
type KMS struct {
|
|
||||||
session *session.Session
|
|
||||||
service KeyManagementClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyManagementClient defines the methods on KeyManagementClient that this
|
|
||||||
// package will use. This interface will be used for unit testing.
|
|
||||||
type KeyManagementClient interface {
|
|
||||||
GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
|
|
||||||
CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
|
|
||||||
CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
|
|
||||||
SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// customerMasterKeySpecMapping is a mapping between the step signature algorithm,
|
|
||||||
// and bits for RSA keys, with awskms CustomerMasterKeySpec.
|
|
||||||
var customerMasterKeySpecMapping = map[apiv1.SignatureAlgorithm]interface{}{
|
|
||||||
apiv1.UnspecifiedSignAlgorithm: kms.CustomerMasterKeySpecEccNistP256,
|
|
||||||
apiv1.SHA256WithRSA: map[int]string{
|
|
||||||
0: kms.CustomerMasterKeySpecRsa3072,
|
|
||||||
2048: kms.CustomerMasterKeySpecRsa2048,
|
|
||||||
3072: kms.CustomerMasterKeySpecRsa3072,
|
|
||||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSA: map[int]string{
|
|
||||||
0: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
},
|
|
||||||
apiv1.SHA256WithRSAPSS: map[int]string{
|
|
||||||
0: kms.CustomerMasterKeySpecRsa3072,
|
|
||||||
2048: kms.CustomerMasterKeySpecRsa2048,
|
|
||||||
3072: kms.CustomerMasterKeySpecRsa3072,
|
|
||||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSAPSS: map[int]string{
|
|
||||||
0: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
4096: kms.CustomerMasterKeySpecRsa4096,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA256: kms.CustomerMasterKeySpecEccNistP256,
|
|
||||||
apiv1.ECDSAWithSHA384: kms.CustomerMasterKeySpecEccNistP384,
|
|
||||||
apiv1.ECDSAWithSHA512: kms.CustomerMasterKeySpecEccNistP521,
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new AWSKMS. By default, sessions will be created using the
|
|
||||||
// credentials in `~/.aws/credentials`, but this can be overridden using the
|
|
||||||
// CredentialsFile option, the Region and Profile can also be configured as
|
|
||||||
// options.
|
|
||||||
//
|
|
||||||
// AWS sessions can also be configured with environment variables, see docs at
|
|
||||||
// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for all the options.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*KMS, error) {
|
|
||||||
var o session.Options
|
|
||||||
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
o.Profile = u.Get("profile")
|
|
||||||
if v := u.Get("region"); v != "" {
|
|
||||||
o.Config.Region = new(string)
|
|
||||||
*o.Config.Region = v
|
|
||||||
}
|
|
||||||
if f := u.Get("credentials-file"); f != "" {
|
|
||||||
o.SharedConfigFiles = []string{f}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated way to set configuration parameters.
|
|
||||||
if opts.Region != "" {
|
|
||||||
o.Config.Region = &opts.Region
|
|
||||||
}
|
|
||||||
if opts.Profile != "" {
|
|
||||||
o.Profile = opts.Profile
|
|
||||||
}
|
|
||||||
if opts.CredentialsFile != "" {
|
|
||||||
o.SharedConfigFiles = []string{opts.CredentialsFile}
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := session.NewSessionWithOptions(o)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error creating AWS session")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &KMS{
|
|
||||||
session: sess,
|
|
||||||
service: kms.New(sess),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.AmazonKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey returns a public key from KMS.
|
|
||||||
func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("getPublicKey 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
keyID, err := parseKeyID(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := k.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
|
|
||||||
KeyId: &keyID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return pemutil.ParseDER(resp.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey generates a new key in KMS and returns the public key version
|
|
||||||
// of it.
|
|
||||||
func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
keySpec, err := getCustomerMasterKeySpecMapping(req.SignatureAlgorithm, req.Bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tag := new(kms.Tag)
|
|
||||||
tag.SetTagKey("name")
|
|
||||||
tag.SetTagValue(req.Name)
|
|
||||||
|
|
||||||
input := &kms.CreateKeyInput{
|
|
||||||
Description: &req.Name,
|
|
||||||
CustomerMasterKeySpec: &keySpec,
|
|
||||||
Tags: []*kms.Tag{tag},
|
|
||||||
}
|
|
||||||
input.SetKeyUsage(kms.KeyUsageTypeSignVerify)
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := k.service.CreateKeyWithContext(ctx, input)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "awskms CreateKeyWithContext failed")
|
|
||||||
}
|
|
||||||
if err := k.createKeyAlias(*resp.KeyMetadata.KeyId, req.Name); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create uri for key
|
|
||||||
name := uri.New("awskms", url.Values{
|
|
||||||
"key-id": []string{*resp.KeyMetadata.KeyId},
|
|
||||||
}).String()
|
|
||||||
|
|
||||||
publicKey, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: name,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Names uses Amazon Resource Name
|
|
||||||
// https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: name,
|
|
||||||
PublicKey: publicKey,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: name,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KMS) createKeyAlias(keyID, alias string) error {
|
|
||||||
alias = "alias/" + alias + "-" + keyID[:8]
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := k.service.CreateAliasWithContext(ctx, &kms.CreateAliasInput{
|
|
||||||
AliasName: &alias,
|
|
||||||
TargetKeyId: &keyID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "awskms CreateAliasWithContext failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner creates a new crypto.Signer with a previously configured key.
|
|
||||||
func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
if req.SigningKey == "" {
|
|
||||||
return nil, errors.New("createSigner 'signingKey' cannot be empty")
|
|
||||||
}
|
|
||||||
return NewSigner(k.service, req.SigningKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the connection of the KMS client.
|
|
||||||
func (k *KMS) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultContext() (context.Context, context.CancelFunc) {
|
|
||||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseKeyID extracts the key-id from an uri.
|
|
||||||
func parseKeyID(name string) (string, error) {
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
if strings.HasPrefix(name, "awskms:") || strings.HasPrefix(name, "aws:") {
|
|
||||||
u, err := uri.Parse(name)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if k := u.Get("key-id"); k != "" {
|
|
||||||
return k, nil
|
|
||||||
}
|
|
||||||
return "", errors.Errorf("failed to get key-id from %s", name)
|
|
||||||
}
|
|
||||||
return name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCustomerMasterKeySpecMapping(alg apiv1.SignatureAlgorithm, bits int) (string, error) {
|
|
||||||
v, ok := customerMasterKeySpecMapping[alg]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.Errorf("awskms does not support signature algorithm '%s'", alg)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := v.(type) {
|
|
||||||
case string:
|
|
||||||
return v, nil
|
|
||||||
case map[int]string:
|
|
||||||
s, ok := v[bits]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.Errorf("awskms does not support signature algorithm '%s' with '%d' bits", alg, bits)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unexpected error: this should not happen")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,364 +0,0 @@
|
||||||
package awskms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kms"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
sess, err := session.NewSessionWithOptions(session.Options{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
expected := &KMS{
|
|
||||||
session: sess,
|
|
||||||
service: kms.New(sess),
|
|
||||||
}
|
|
||||||
|
|
||||||
// This will force an error in the session creation.
|
|
||||||
// It does not fail with missing credentials.
|
|
||||||
forceError := func(t *testing.T) {
|
|
||||||
key := "AWS_CA_BUNDLE"
|
|
||||||
value := os.Getenv(key)
|
|
||||||
os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt"))
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if value == "" {
|
|
||||||
os.Unsetenv(key)
|
|
||||||
} else {
|
|
||||||
os.Setenv(key, value)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *KMS
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{ctx, apiv1.Options{}}, expected, false},
|
|
||||||
{"ok with options", args{ctx, apiv1.Options{
|
|
||||||
Region: "us-east-1",
|
|
||||||
Profile: "smallstep",
|
|
||||||
CredentialsFile: "~/aws/credentials",
|
|
||||||
}}, expected, false},
|
|
||||||
{"ok with uri", args{ctx, apiv1.Options{
|
|
||||||
URI: "awskms:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
|
|
||||||
}}, expected, false},
|
|
||||||
{"fail", args{ctx, apiv1.Options{}}, nil, true},
|
|
||||||
{"fail uri", args{ctx, apiv1.Options{
|
|
||||||
URI: "pkcs11:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Force an error in the session loading
|
|
||||||
if tt.wantErr {
|
|
||||||
forceError(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if got.session == nil || got.service == nil {
|
|
||||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKMS_GetPublicKey(t *testing.T) {
|
|
||||||
okClient := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
session *session.Session
|
|
||||||
service KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
}}, key, false},
|
|
||||||
{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
|
||||||
{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "awskms:key-id=",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail getPublicKey", fields{nil, &MockClient{
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail not der", fields{nil, &MockClient{
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return &kms.GetPublicKeyOutput{
|
|
||||||
KeyId: input.KeyId,
|
|
||||||
PublicKey: []byte(publicKey),
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KMS{
|
|
||||||
session: tt.fields.session,
|
|
||||||
service: tt.fields.service,
|
|
||||||
}
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKMS_CreateKey(t *testing.T) {
|
|
||||||
okClient := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
session *session.Session
|
|
||||||
service KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
PublicKey: key,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
Bits: 2048,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
PublicKey: key,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
|
||||||
{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.PureEd25519,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
Bits: 1234,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail createKey", fields{nil, &MockClient{
|
|
||||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
createAliasWithContext: okClient.createAliasWithContext,
|
|
||||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
|
||||||
}}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail createAlias", fields{nil, &MockClient{
|
|
||||||
createKeyWithContext: okClient.createKeyWithContext,
|
|
||||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
|
||||||
}}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail getPublicKey", fields{nil, &MockClient{
|
|
||||||
createKeyWithContext: okClient.createKeyWithContext,
|
|
||||||
createAliasWithContext: okClient.createAliasWithContext,
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "root",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KMS{
|
|
||||||
session: tt.fields.session,
|
|
||||||
service: tt.fields.service,
|
|
||||||
}
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKMS_CreateSigner(t *testing.T) {
|
|
||||||
client := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
session *session.Session
|
|
||||||
service KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
}}, &Signer{
|
|
||||||
service: client,
|
|
||||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
publicKey: key,
|
|
||||||
}, false},
|
|
||||||
{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
||||||
{"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KMS{
|
|
||||||
session: tt.fields.session,
|
|
||||||
service: tt.fields.service,
|
|
||||||
}
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKMS_Close(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
session *session.Session
|
|
||||||
service KeyManagementClient
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{nil, getOKClient()}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KMS{
|
|
||||||
session: tt.fields.session,
|
|
||||||
service: tt.fields.service,
|
|
||||||
}
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_parseKeyID(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
||||||
{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
||||||
{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
||||||
{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
|
|
||||||
{"fail empty key", args{"awskms:key-id="}, "", true},
|
|
||||||
{"fail missing", args{"awskms:foo=bar"}, "", true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := parseKeyID(tt.args.name)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,72 +0,0 @@
|
||||||
package awskms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/pem"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kms"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockClient struct {
|
|
||||||
getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
|
|
||||||
createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
|
|
||||||
createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
|
|
||||||
signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return m.getPublicKeyWithContext(ctx, input, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
|
||||||
return m.createKeyWithContext(ctx, input, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
|
||||||
return m.createAliasWithContext(ctx, input, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
|
||||||
return m.signWithContext(ctx, input, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
publicKey = `-----BEGIN PUBLIC KEY-----
|
|
||||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE8XWlIWkOThxNjGbZLYUgRHmsvCrW
|
|
||||||
KF+HLktPfPTIK3lGd1k4849WQs59XIN+LXZQ6b2eRBEBKAHEyQus8UU7gw==
|
|
||||||
-----END PUBLIC KEY-----`
|
|
||||||
keyID = "be468355-ca7a-40d9-a28b-8ae1c4c7f936"
|
|
||||||
)
|
|
||||||
|
|
||||||
var signature = []byte{
|
|
||||||
0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
|
|
||||||
0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOKClient() *MockClient {
|
|
||||||
return &MockClient{
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
block, _ := pem.Decode([]byte(publicKey))
|
|
||||||
return &kms.GetPublicKeyOutput{
|
|
||||||
KeyId: input.KeyId,
|
|
||||||
PublicKey: block.Bytes,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
|
||||||
md := new(kms.KeyMetadata)
|
|
||||||
md.SetKeyId(keyID)
|
|
||||||
return &kms.CreateKeyOutput{
|
|
||||||
KeyMetadata: md,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
|
||||||
return &kms.CreateAliasOutput{}, nil
|
|
||||||
},
|
|
||||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
|
||||||
return &kms.SignOutput{
|
|
||||||
Signature: signature,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,122 +0,0 @@
|
||||||
package awskms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rsa"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/service/kms"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signer implements a crypto.Signer using the AWS KMS.
|
|
||||||
type Signer struct {
|
|
||||||
service KeyManagementClient
|
|
||||||
keyID string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSigner creates a new signer using a key in the AWS KMS.
|
|
||||||
func NewSigner(svc KeyManagementClient, signingKey string) (*Signer, error) {
|
|
||||||
keyID, err := parseKeyID(signingKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure that the key exists.
|
|
||||||
signer := &Signer{
|
|
||||||
service: svc,
|
|
||||||
keyID: keyID,
|
|
||||||
}
|
|
||||||
if err := signer.preloadKey(keyID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Signer) preloadKey(keyID string) error {
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := s.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
|
|
||||||
KeyId: &keyID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "awskms GetPublicKeyWithContext failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.publicKey, err = pemutil.ParseDER(resp.PublicKey)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public returns the public key of this signer or an error.
|
|
||||||
func (s *Signer) Public() crypto.PublicKey {
|
|
||||||
return s.publicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign signs digest with the private key stored in the AWS KMS.
|
|
||||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
|
||||||
alg, err := getSigningAlgorithm(s.Public(), opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &kms.SignInput{
|
|
||||||
KeyId: &s.keyID,
|
|
||||||
SigningAlgorithm: &alg,
|
|
||||||
Message: digest,
|
|
||||||
}
|
|
||||||
req.SetMessageType("DIGEST")
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := s.service.SignWithContext(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "awsKMS SignWithContext failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.Signature, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string, error) {
|
|
||||||
switch key.(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
_, isPSS := opts.(*rsa.PSSOptions)
|
|
||||||
switch h := opts.HashFunc(); h {
|
|
||||||
case crypto.SHA256:
|
|
||||||
if isPSS {
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPssSha256, nil
|
|
||||||
}
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
|
|
||||||
case crypto.SHA384:
|
|
||||||
if isPSS {
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPssSha384, nil
|
|
||||||
}
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
|
|
||||||
case crypto.SHA512:
|
|
||||||
if isPSS {
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPssSha512, nil
|
|
||||||
}
|
|
||||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported hash function %v", h)
|
|
||||||
}
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
switch h := opts.HashFunc(); h {
|
|
||||||
case crypto.SHA256:
|
|
||||||
return kms.SigningAlgorithmSpecEcdsaSha256, nil
|
|
||||||
case crypto.SHA384:
|
|
||||||
return kms.SigningAlgorithmSpecEcdsaSha384, nil
|
|
||||||
case crypto.SHA512:
|
|
||||||
return kms.SigningAlgorithmSpecEcdsaSha512, nil
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported hash function %v", h)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported key type %T", key)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,191 +0,0 @@
|
||||||
package awskms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kms"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewSigner(t *testing.T) {
|
|
||||||
okClient := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
svc KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{
|
|
||||||
service: okClient,
|
|
||||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
||||||
publicKey: key,
|
|
||||||
}, false},
|
|
||||||
{"fail parse", args{okClient, "awskms:key-id="}, nil, true},
|
|
||||||
{"fail preload", args{&MockClient{
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
|
||||||
{"fail preload not der", args{&MockClient{
|
|
||||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
||||||
return &kms.GetPublicKeyOutput{
|
|
||||||
KeyId: input.KeyId,
|
|
||||||
PublicKey: []byte(publicKey),
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NewSigner(tt.args.svc, tt.args.signingKey)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_Public(t *testing.T) {
|
|
||||||
okClient := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
service KeyManagementClient
|
|
||||||
keyID string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want crypto.PublicKey
|
|
||||||
}{
|
|
||||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
service: tt.fields.service,
|
|
||||||
keyID: tt.fields.keyID,
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_Sign(t *testing.T) {
|
|
||||||
okClient := getOKClient()
|
|
||||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
service KeyManagementClient
|
|
||||||
keyID string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
rand io.Reader
|
|
||||||
digest []byte
|
|
||||||
opts crypto.SignerOpts
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
|
|
||||||
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
|
|
||||||
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
|
||||||
{"fail sign", fields{&MockClient{
|
|
||||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
service: tt.fields.service,
|
|
||||||
keyID: tt.fields.keyID,
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_getSigningAlgorithm(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
key crypto.PublicKey
|
|
||||||
opts crypto.SignerOpts
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
|
|
||||||
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
|
|
||||||
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
|
|
||||||
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
|
|
||||||
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
|
|
||||||
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
|
|
||||||
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
|
|
||||||
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
|
|
||||||
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
|
|
||||||
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
|
|
||||||
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
|
|
||||||
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,81 +0,0 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient)
|
|
||||||
|
|
||||||
// Package mock is a generated GoMock package.
|
|
||||||
package mock
|
|
||||||
|
|
||||||
import (
|
|
||||||
context "context"
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeyVaultClient is a mock of KeyVaultClient interface
|
|
||||||
type KeyVaultClient struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *KeyVaultClientMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient
|
|
||||||
type KeyVaultClientMockRecorder struct {
|
|
||||||
mock *KeyVaultClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewKeyVaultClient creates a new mock instance
|
|
||||||
func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient {
|
|
||||||
mock := &KeyVaultClient{ctrl: ctrl}
|
|
||||||
mock.recorder = &KeyVaultClientMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey mocks base method
|
|
||||||
func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3)
|
|
||||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey indicates an expected call of CreateKey
|
|
||||||
func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKey mocks base method
|
|
||||||
func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3)
|
|
||||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKey indicates an expected call of GetKey
|
|
||||||
func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign mocks base method
|
|
||||||
func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4)
|
|
||||||
ret0, _ := ret[0].(keyvault.KeyOperationResult)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign indicates an expected call of Sign
|
|
||||||
func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4)
|
|
||||||
}
|
|
|
@ -1,342 +0,0 @@
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"regexp"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
|
||||||
"github.com/Azure/go-autorest/autorest/date"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scheme is the scheme used for the Azure Key Vault uris.
|
|
||||||
const Scheme = "azurekms"
|
|
||||||
|
|
||||||
// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can
|
|
||||||
// extract the vault, name and version of the key.
|
|
||||||
var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`)
|
|
||||||
|
|
||||||
var (
|
|
||||||
valueTrue = true
|
|
||||||
value2048 int32 = 2048
|
|
||||||
value3072 int32 = 3072
|
|
||||||
value4096 int32 = 4096
|
|
||||||
)
|
|
||||||
|
|
||||||
var now = func() time.Time {
|
|
||||||
return time.Now().UTC()
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyType struct {
|
|
||||||
Kty keyvault.JSONWebKeyType
|
|
||||||
Curve keyvault.JSONWebKeyCurveName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType {
|
|
||||||
switch k.Kty {
|
|
||||||
case keyvault.EC:
|
|
||||||
if pl == apiv1.HSM {
|
|
||||||
return keyvault.ECHSM
|
|
||||||
}
|
|
||||||
return k.Kty
|
|
||||||
case keyvault.RSA:
|
|
||||||
if pl == apiv1.HSM {
|
|
||||||
return keyvault.RSAHSM
|
|
||||||
}
|
|
||||||
return k.Kty
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{
|
|
||||||
apiv1.UnspecifiedSignAlgorithm: {
|
|
||||||
Kty: keyvault.EC,
|
|
||||||
Curve: keyvault.P256,
|
|
||||||
},
|
|
||||||
apiv1.SHA256WithRSA: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.SHA384WithRSA: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSA: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.SHA256WithRSAPSS: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.SHA384WithRSAPSS: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSAPSS: {
|
|
||||||
Kty: keyvault.RSA,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA256: {
|
|
||||||
Kty: keyvault.EC,
|
|
||||||
Curve: keyvault.P256,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA384: {
|
|
||||||
Kty: keyvault.EC,
|
|
||||||
Curve: keyvault.P384,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA512: {
|
|
||||||
Kty: keyvault.EC,
|
|
||||||
Curve: keyvault.P521,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// vaultResource is the value the client will use as audience.
|
|
||||||
const vaultResource = "https://vault.azure.net"
|
|
||||||
|
|
||||||
// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will
|
|
||||||
// be used for testing purposes.
|
|
||||||
type KeyVaultClient interface {
|
|
||||||
GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error)
|
|
||||||
CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error)
|
|
||||||
Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyVault implements a KMS using Azure Key Vault.
|
|
||||||
//
|
|
||||||
// The URI format used in Azure Key Vault is the following:
|
|
||||||
//
|
|
||||||
// - azurekms:name=key-name;vault=vault-name
|
|
||||||
// - azurekms:name=key-name;vault=vault-name?version=key-version
|
|
||||||
// - azurekms:name=key-name;vault=vault-name?hsm=true
|
|
||||||
//
|
|
||||||
// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault
|
|
||||||
// name where the key is located; "version" is an optional parameter that
|
|
||||||
// defines the version of they key, if version is not given, the latest one will
|
|
||||||
// be used; "hsm" defines if an HSM want to be used for this key, this is
|
|
||||||
// specially useful when this is used from `step`.
|
|
||||||
//
|
|
||||||
// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault
|
|
||||||
// package, at some point Azure might create a keyvault client with all the
|
|
||||||
// functionality in /sdk/keyvault, we should migrate to that once available.
|
|
||||||
type KeyVault struct {
|
|
||||||
baseClient KeyVaultClient
|
|
||||||
defaults DefaultOptions
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultOptions are custom options that can be passed as defaults using the
|
|
||||||
// URI in apiv1.Options.
|
|
||||||
type DefaultOptions struct {
|
|
||||||
Vault string
|
|
||||||
ProtectionLevel apiv1.ProtectionLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
baseClient := keyvault.New()
|
|
||||||
|
|
||||||
// With an URI, try to log in only using client credentials in the URI.
|
|
||||||
// Client credentials requires:
|
|
||||||
// - client-id
|
|
||||||
// - client-secret
|
|
||||||
// - tenant-id
|
|
||||||
// And optionally the aad-endpoint to support custom clouds:
|
|
||||||
// - aad-endpoint (defaults to https://login.microsoftonline.com/)
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Required options
|
|
||||||
clientID := u.Get("client-id")
|
|
||||||
clientSecret := u.Get("client-secret")
|
|
||||||
tenantID := u.Get("tenant-id")
|
|
||||||
// optional
|
|
||||||
aadEndpoint := u.Get("aad-endpoint")
|
|
||||||
|
|
||||||
if clientID != "" && clientSecret != "" && tenantID != "" {
|
|
||||||
s := auth.EnvironmentSettings{
|
|
||||||
Values: map[string]string{
|
|
||||||
auth.ClientID: clientID,
|
|
||||||
auth.ClientSecret: clientSecret,
|
|
||||||
auth.TenantID: tenantID,
|
|
||||||
auth.Resource: vaultResource,
|
|
||||||
},
|
|
||||||
Environment: azure.PublicCloud,
|
|
||||||
}
|
|
||||||
if aadEndpoint != "" {
|
|
||||||
s.Environment.ActiveDirectoryEndpoint = aadEndpoint
|
|
||||||
}
|
|
||||||
baseClient.Authorizer, err = s.GetAuthorizer()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return baseClient, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to authorize with the following methods:
|
|
||||||
// 1. Environment variables.
|
|
||||||
// - Client credentials
|
|
||||||
// - Client certificate
|
|
||||||
// - Username and password
|
|
||||||
// - MSI
|
|
||||||
// 2. Using Azure CLI 2.0 on local development.
|
|
||||||
authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource)
|
|
||||||
if err != nil {
|
|
||||||
authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error getting authorizer for key vault")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
baseClient.Authorizer = authorizer
|
|
||||||
return &baseClient, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New initializes a new KMS implemented using Azure Key Vault.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) {
|
|
||||||
baseClient, err := createClient(ctx, opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// step and step-ca do not need and URI, but having a default vault and
|
|
||||||
// protection level is useful if this package is used as an api
|
|
||||||
var defaults DefaultOptions
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defaults.Vault = u.Get("vault")
|
|
||||||
if u.GetBool("hsm") {
|
|
||||||
defaults.ProtectionLevel = apiv1.HSM
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &KeyVault{
|
|
||||||
baseClient: baseClient,
|
|
||||||
defaults: defaults,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey loads a public key from Azure Key Vault by its resource name.
|
|
||||||
func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
vault, name, version, _, err := parseKeyName(req.Name, k.defaults)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "keyVault GetKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return convertKey(resp.Key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey creates a asymmetric key in Azure Key Vault.
|
|
||||||
func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override protection level to HSM only if it's not specified, and is given
|
|
||||||
// in the uri.
|
|
||||||
protectionLevel := req.ProtectionLevel
|
|
||||||
if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm {
|
|
||||||
protectionLevel = apiv1.HSM
|
|
||||||
}
|
|
||||||
|
|
||||||
kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
|
||||||
}
|
|
||||||
var keySize *int32
|
|
||||||
if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM {
|
|
||||||
switch req.Bits {
|
|
||||||
case 2048:
|
|
||||||
keySize = &value2048
|
|
||||||
case 0, 3072:
|
|
||||||
keySize = &value3072
|
|
||||||
case 4096:
|
|
||||||
keySize = &value4096
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("keyVault does not support key size %d", req.Bits)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
created := date.UnixTime(now())
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{
|
|
||||||
Kty: kt.KeyType(protectionLevel),
|
|
||||||
KeySize: keySize,
|
|
||||||
Curve: kt.Curve,
|
|
||||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
|
||||||
keyvault.Sign, keyvault.Verify,
|
|
||||||
},
|
|
||||||
KeyAttributes: &keyvault.KeyAttributes{
|
|
||||||
Enabled: &valueTrue,
|
|
||||||
Created: &created,
|
|
||||||
NotBefore: &created,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "keyVault CreateKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKey, err := convertKey(resp.Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
keyURI := getKeyName(vault, name, resp)
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: keyURI,
|
|
||||||
PublicKey: publicKey,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: keyURI,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner returns a crypto.Signer from a previously created asymmetric key.
|
|
||||||
func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
if req.SigningKey == "" {
|
|
||||||
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
|
||||||
}
|
|
||||||
return NewSigner(k.baseClient, req.SigningKey, k.defaults)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the client connection to the Azure Key Vault. This is a noop.
|
|
||||||
func (k *KeyVault) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateName validates that the given string is a valid URI.
|
|
||||||
func (k *KeyVault) ValidateName(s string) error {
|
|
||||||
_, _, _, _, err := parseKeyName(s, k.defaults)
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,653 +0,0 @@
|
||||||
//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient
|
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/Azure/go-autorest/autorest/date"
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/azurekms/internal/mock"
|
|
||||||
"go.step.sm/crypto/keyutil"
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errTest = fmt.Errorf("test error")
|
|
||||||
|
|
||||||
func mockNow(t *testing.T) time.Time {
|
|
||||||
old := now
|
|
||||||
t0 := time.Unix(1234567890, 123).UTC()
|
|
||||||
now = func() time.Time {
|
|
||||||
return t0
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
now = old
|
|
||||||
})
|
|
||||||
return t0
|
|
||||||
}
|
|
||||||
|
|
||||||
func mockClient(t *testing.T) *mock.KeyVaultClient {
|
|
||||||
t.Helper()
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
ctrl.Finish()
|
|
||||||
})
|
|
||||||
return mock.NewKeyVaultClient(ctrl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey {
|
|
||||||
t.Helper()
|
|
||||||
b, err := json.Marshal(&jose.JSONWebKey{
|
|
||||||
Key: pub,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
key := new(keyvault.JSONWebKey)
|
|
||||||
if err := json.Unmarshal(b, key); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_now(t *testing.T) {
|
|
||||||
t0 := now()
|
|
||||||
if loc := t0.Location(); loc != time.UTC {
|
|
||||||
t.Errorf("now() Location = %v, want %v", loc, time.UTC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
client := mockClient(t)
|
|
||||||
old := createClient
|
|
||||||
t.Cleanup(func() {
|
|
||||||
createClient = old
|
|
||||||
})
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func()
|
|
||||||
args args
|
|
||||||
want *KeyVault
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", func() {
|
|
||||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
}, args{context.Background(), apiv1.Options{}}, &KeyVault{
|
|
||||||
baseClient: client,
|
|
||||||
}, false},
|
|
||||||
{"ok with vault", func() {
|
|
||||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
}, args{context.Background(), apiv1.Options{
|
|
||||||
URI: "azurekms:vault=my-vault",
|
|
||||||
}}, &KeyVault{
|
|
||||||
baseClient: client,
|
|
||||||
defaults: DefaultOptions{
|
|
||||||
Vault: "my-vault",
|
|
||||||
ProtectionLevel: apiv1.UnspecifiedProtectionLevel,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok with vault + hsm", func() {
|
|
||||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
}, args{context.Background(), apiv1.Options{
|
|
||||||
URI: "azurekms:vault=my-vault;hsm=true",
|
|
||||||
}}, &KeyVault{
|
|
||||||
baseClient: client,
|
|
||||||
defaults: DefaultOptions{
|
|
||||||
Vault: "my-vault",
|
|
||||||
ProtectionLevel: apiv1.HSM,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"fail", func() {
|
|
||||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
return nil, errTest
|
|
||||||
}
|
|
||||||
}, args{context.Background(), apiv1.Options{}}, nil, true},
|
|
||||||
{"fail uri", func() {
|
|
||||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
}, args{context.Background(), apiv1.Options{
|
|
||||||
URI: "kms:vault=my-vault;hsm=true",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
tt.setup()
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_createClient(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
skip bool
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{context.Background(), apiv1.Options{}}, true, false},
|
|
||||||
{"ok with uri", args{context.Background(), apiv1.Options{
|
|
||||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id",
|
|
||||||
}}, false, false},
|
|
||||||
{"ok with uri+aad", args{context.Background(), apiv1.Options{
|
|
||||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F",
|
|
||||||
}}, false, false},
|
|
||||||
{"ok with uri no config", args{context.Background(), apiv1.Options{
|
|
||||||
URI: "azurekms:",
|
|
||||||
}}, true, false},
|
|
||||||
{"fail uri", args{context.Background(), apiv1.Options{
|
|
||||||
URI: "kms:client-id=id;client-secret=secret;tenant-id=id",
|
|
||||||
}}, false, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.skip {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
_, err := createClient(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_GetPublicKey(t *testing.T) {
|
|
||||||
key, err := keyutil.GenerateDefaultSigner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := key.Public()
|
|
||||||
jwk := createJWK(t, pub)
|
|
||||||
|
|
||||||
client := mockClient(t)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
baseClient KeyVaultClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
}}, pub, false},
|
|
||||||
{"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key?version=my-version",
|
|
||||||
}}, pub, false},
|
|
||||||
{"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=not-found?version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "azurekms:vault=;name=?version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KeyVault{
|
|
||||||
baseClient: tt.fields.baseClient,
|
|
||||||
}
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_CreateKey(t *testing.T) {
|
|
||||||
ecKey, err := keyutil.GenerateDefaultSigner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ecPub := ecKey.Public()
|
|
||||||
rsaPub := rsaKey.Public()
|
|
||||||
ecJWK := createJWK(t, ecPub)
|
|
||||||
rsaJWK := createJWK(t, rsaPub)
|
|
||||||
|
|
||||||
t0 := date.UnixTime(mockNow(t))
|
|
||||||
client := mockClient(t)
|
|
||||||
|
|
||||||
expects := []struct {
|
|
||||||
Name string
|
|
||||||
Kty keyvault.JSONWebKeyType
|
|
||||||
KeySize *int32
|
|
||||||
Curve keyvault.JSONWebKeyCurveName
|
|
||||||
Key *keyvault.JSONWebKey
|
|
||||||
}{
|
|
||||||
{"P-256", keyvault.EC, nil, keyvault.P256, ecJWK},
|
|
||||||
{"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
|
||||||
{"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
|
||||||
{"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK},
|
|
||||||
{"P-384", keyvault.EC, nil, keyvault.P384, ecJWK},
|
|
||||||
{"P-521", keyvault.EC, nil, keyvault.P521, ecJWK},
|
|
||||||
{"RSA 0", keyvault.RSA, &value3072, "", rsaJWK},
|
|
||||||
{"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
|
||||||
{"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
|
||||||
{"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK},
|
|
||||||
{"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK},
|
|
||||||
{"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range expects {
|
|
||||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{
|
|
||||||
Kty: e.Kty,
|
|
||||||
KeySize: e.KeySize,
|
|
||||||
Curve: e.Curve,
|
|
||||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
|
||||||
keyvault.Sign, keyvault.Verify,
|
|
||||||
},
|
|
||||||
KeyAttributes: &keyvault.KeyAttributes{
|
|
||||||
Enabled: &valueTrue,
|
|
||||||
Created: &t0,
|
|
||||||
NotBefore: &t0,
|
|
||||||
},
|
|
||||||
}).Return(keyvault.KeyBundle{
|
|
||||||
Key: e.Key,
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest)
|
|
||||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{
|
|
||||||
Key: nil,
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
baseClient KeyVaultClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
ProtectionLevel: apiv1.Software,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
ProtectionLevel: apiv1.HSM,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key?hsm=true",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: ecPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
Bits: 0,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
ProtectionLevel: apiv1.Software,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
Bits: 0,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
|
||||||
ProtectionLevel: apiv1.HSM,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key;hsm=true",
|
|
||||||
Bits: 0,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
Bits: 2048,
|
|
||||||
SignatureAlgorithm: apiv1.SHA384WithRSA,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
Bits: 3072,
|
|
||||||
SignatureAlgorithm: apiv1.SHA512WithRSA,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
Bits: 4096,
|
|
||||||
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
PublicKey: rsaPub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=not-found",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=not-found",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail name", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail id", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=?version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=not-found",
|
|
||||||
SignatureAlgorithm: apiv1.PureEd25519,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "azurekms:vault=my-vault;name=not-found",
|
|
||||||
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
|
|
||||||
Bits: 1024,
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KeyVault{
|
|
||||||
baseClient: tt.fields.baseClient,
|
|
||||||
}
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_CreateSigner(t *testing.T) {
|
|
||||||
key, err := keyutil.GenerateDefaultSigner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := key.Public()
|
|
||||||
jwk := createJWK(t, pub)
|
|
||||||
|
|
||||||
client := mockClient(t)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
baseClient KeyVaultClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{client}, args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:vault=my-vault;name=my-key",
|
|
||||||
}}, &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
|
||||||
name: "my-key",
|
|
||||||
version: "",
|
|
||||||
publicKey: pub,
|
|
||||||
}, false},
|
|
||||||
{"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version",
|
|
||||||
}}, &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
|
||||||
name: "my-key",
|
|
||||||
version: "my-version",
|
|
||||||
publicKey: pub,
|
|
||||||
}, false},
|
|
||||||
{"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KeyVault{
|
|
||||||
baseClient: tt.fields.baseClient,
|
|
||||||
}
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_Close(t *testing.T) {
|
|
||||||
client := mockClient(t)
|
|
||||||
type fields struct {
|
|
||||||
baseClient KeyVaultClient
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{client}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KeyVault{
|
|
||||||
baseClient: tt.fields.baseClient,
|
|
||||||
}
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_keyType_KeyType(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Kty keyvault.JSONWebKeyType
|
|
||||||
Curve keyvault.JSONWebKeyCurveName
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
pl apiv1.ProtectionLevel
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want keyvault.JSONWebKeyType
|
|
||||||
}{
|
|
||||||
{"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC},
|
|
||||||
{"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC},
|
|
||||||
{"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM},
|
|
||||||
{"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA},
|
|
||||||
{"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA},
|
|
||||||
{"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM},
|
|
||||||
{"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := keyType{
|
|
||||||
Kty: tt.fields.Kty,
|
|
||||||
Curve: tt.fields.Curve,
|
|
||||||
}
|
|
||||||
if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyVault_ValidateName(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
s string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"azurekms:name=my-key;vault=my-vault"}, false},
|
|
||||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false},
|
|
||||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true},
|
|
||||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true},
|
|
||||||
{"fail no name", args{"azurekms:vault=my-vault"}, true},
|
|
||||||
{"fail no vault", args{"azurekms:name=my-key"}, true},
|
|
||||||
{"fail empty", args{""}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &KeyVault{}
|
|
||||||
if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,182 +0,0 @@
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signer implements a crypto.Signer using the AWS KMS.
|
|
||||||
type Signer struct {
|
|
||||||
client KeyVaultClient
|
|
||||||
vaultBaseURL string
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSigner creates a new signer using a key in the AWS KMS.
|
|
||||||
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
|
|
||||||
vault, name, version, _, err := parseKeyName(signingKey, defaults)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure that the key exists.
|
|
||||||
signer := &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: vaultBaseURL(vault),
|
|
||||||
name: name,
|
|
||||||
version: version,
|
|
||||||
}
|
|
||||||
if err := signer.preloadKey(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Signer) preloadKey() error {
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "keyVault GetKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.publicKey, err = convertKey(resp.Key)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public returns the public key of this signer or an error.
|
|
||||||
func (s *Signer) Public() crypto.PublicKey {
|
|
||||||
return s.publicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign signs digest with the private key stored in the AWS KMS.
|
|
||||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
|
||||||
alg, err := getSigningAlgorithm(s.Public(), opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
|
||||||
|
|
||||||
// Sign with retry if the key is not ready
|
|
||||||
resp, err := s.signWithRetry(alg, b64, 3)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "keyVault Sign failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error decoding keyVault Sign result")
|
|
||||||
}
|
|
||||||
|
|
||||||
var octetSize int
|
|
||||||
switch alg {
|
|
||||||
case keyvault.ES256:
|
|
||||||
octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
|
|
||||||
case keyvault.ES384:
|
|
||||||
octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
|
|
||||||
case keyvault.ES512:
|
|
||||||
octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
|
|
||||||
default:
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to asn1
|
|
||||||
if len(sig) != octetSize*2 {
|
|
||||||
return nil, errors.Errorf("keyVault Sign failed: unexpected signature length")
|
|
||||||
}
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
|
||||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
|
|
||||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
|
|
||||||
})
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) {
|
|
||||||
retry:
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
|
||||||
Algorithm: alg,
|
|
||||||
Value: &b64,
|
|
||||||
})
|
|
||||||
if err != nil && retryAttempts > 0 {
|
|
||||||
var requestError *azure.RequestError
|
|
||||||
if errors.As(err, &requestError) {
|
|
||||||
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
|
|
||||||
code, ok := se.InnerError["code"].(string)
|
|
||||||
if ok && code == "KeyNotYetValid" {
|
|
||||||
time.Sleep(time.Second / time.Duration(retryAttempts))
|
|
||||||
retryAttempts--
|
|
||||||
goto retry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
|
||||||
switch key.(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
hashFunc := opts.HashFunc()
|
|
||||||
pss, isPSS := opts.(*rsa.PSSOptions)
|
|
||||||
// Random salt lengths are not supported
|
|
||||||
if isPSS &&
|
|
||||||
pss.SaltLength != rsa.PSSSaltLengthAuto &&
|
|
||||||
pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
|
|
||||||
pss.SaltLength != hashFunc.Size() {
|
|
||||||
return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h := hashFunc; h {
|
|
||||||
case crypto.SHA256:
|
|
||||||
if isPSS {
|
|
||||||
return keyvault.PS256, nil
|
|
||||||
}
|
|
||||||
return keyvault.RS256, nil
|
|
||||||
case crypto.SHA384:
|
|
||||||
if isPSS {
|
|
||||||
return keyvault.PS384, nil
|
|
||||||
}
|
|
||||||
return keyvault.RS384, nil
|
|
||||||
case crypto.SHA512:
|
|
||||||
if isPSS {
|
|
||||||
return keyvault.PS512, nil
|
|
||||||
}
|
|
||||||
return keyvault.RS512, nil
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported hash function %v", h)
|
|
||||||
}
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
switch h := opts.HashFunc(); h {
|
|
||||||
case crypto.SHA256:
|
|
||||||
return keyvault.ES256, nil
|
|
||||||
case crypto.SHA384:
|
|
||||||
return keyvault.ES384, nil
|
|
||||||
case crypto.SHA512:
|
|
||||||
return keyvault.ES512, nil
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported hash function %v", h)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return "", errors.Errorf("unsupported key type %T", key)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,493 +0,0 @@
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"io"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/Azure/go-autorest/autorest"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure"
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"go.step.sm/crypto/keyutil"
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewSigner(t *testing.T) {
|
|
||||||
key, err := keyutil.GenerateDefaultSigner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := key.Public()
|
|
||||||
jwk := createJWK(t, pub)
|
|
||||||
|
|
||||||
client := mockClient(t)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
|
||||||
Key: jwk,
|
|
||||||
}, nil)
|
|
||||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
|
||||||
|
|
||||||
var noOptions DefaultOptions
|
|
||||||
type args struct {
|
|
||||||
client KeyVaultClient
|
|
||||||
signingKey string
|
|
||||||
defaults DefaultOptions
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
|
||||||
name: "my-key",
|
|
||||||
version: "",
|
|
||||||
publicKey: pub,
|
|
||||||
}, false},
|
|
||||||
{"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
|
||||||
name: "my-key",
|
|
||||||
version: "my-version",
|
|
||||||
publicKey: pub,
|
|
||||||
}, false},
|
|
||||||
{"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{
|
|
||||||
client: client,
|
|
||||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
|
||||||
name: "my-key",
|
|
||||||
version: "my-version",
|
|
||||||
publicKey: pub,
|
|
||||||
}, false},
|
|
||||||
{"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
|
||||||
{"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true},
|
|
||||||
{"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true},
|
|
||||||
{"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_Public(t *testing.T) {
|
|
||||||
key, err := keyutil.GenerateDefaultSigner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := key.Public()
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want crypto.PublicKey
|
|
||||||
}{
|
|
||||||
{"ok", fields{pub}, pub},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_Sign(t *testing.T) {
|
|
||||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
|
||||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
h := opts.HashFunc().New()
|
|
||||||
h.Write([]byte("random-data"))
|
|
||||||
sum := h.Sum(nil)
|
|
||||||
|
|
||||||
var sig, resultSig []byte
|
|
||||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
|
||||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
curveBits := priv.Params().BitSize
|
|
||||||
keyBytes := curveBits / 8
|
|
||||||
if curveBits%8 > 0 {
|
|
||||||
keyBytes++
|
|
||||||
}
|
|
||||||
rBytes := r.Bytes()
|
|
||||||
rBytesPadded := make([]byte, keyBytes)
|
|
||||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
|
||||||
|
|
||||||
sBytes := s.Bytes()
|
|
||||||
sBytesPadded := make([]byte, keyBytes)
|
|
||||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
|
||||||
// nolint:gocritic
|
|
||||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
|
||||||
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
|
||||||
b.AddASN1BigInt(r)
|
|
||||||
b.AddASN1BigInt(s)
|
|
||||||
})
|
|
||||||
sig, err = b.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
resultSig = sig
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
|
||||||
}
|
|
||||||
|
|
||||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
|
||||||
p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384)
|
|
||||||
p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512)
|
|
||||||
rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256)
|
|
||||||
rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384)
|
|
||||||
rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512)
|
|
||||||
rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthAuto,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
})
|
|
||||||
rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthAuto,
|
|
||||||
Hash: crypto.SHA512,
|
|
||||||
})
|
|
||||||
rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthAuto,
|
|
||||||
Hash: crypto.SHA512,
|
|
||||||
})
|
|
||||||
|
|
||||||
ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := mockClient(t)
|
|
||||||
expects := []struct {
|
|
||||||
name string
|
|
||||||
keyVersion string
|
|
||||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
|
||||||
digest []byte
|
|
||||||
result keyvault.KeyOperationResult
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &p256ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &p386ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &p521ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaSHA256ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaSHA384ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaSHA512ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaPSSSHA256ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaPSSSHA384ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaPSSSHA512ResultSig,
|
|
||||||
}, nil},
|
|
||||||
// Errors
|
|
||||||
{"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest},
|
|
||||||
{"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: &rsaSHA256ResultSig,
|
|
||||||
}, nil},
|
|
||||||
{"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
|
||||||
Result: func() *string {
|
|
||||||
v := "😎"
|
|
||||||
return &v
|
|
||||||
}(),
|
|
||||||
}, nil},
|
|
||||||
}
|
|
||||||
for _, e := range expects {
|
|
||||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
|
||||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
|
||||||
Algorithm: e.alg,
|
|
||||||
Value: &value,
|
|
||||||
}).Return(e.result, e.err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyVaultClient
|
|
||||||
vaultBaseURL string
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
rand io.Reader
|
|
||||||
digest []byte
|
|
||||||
opts crypto.SignerOpts
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.SHA256,
|
|
||||||
}, p256Sig, false},
|
|
||||||
{"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{
|
|
||||||
rand.Reader, p384Digest, crypto.SHA384,
|
|
||||||
}, p384Sig, false},
|
|
||||||
{"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{
|
|
||||||
rand.Reader, p521Digest, crypto.SHA512,
|
|
||||||
}, p521Sig, false},
|
|
||||||
{"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
|
||||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
|
||||||
}, rsaSHA256Sig, false},
|
|
||||||
{"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{
|
|
||||||
rand.Reader, rsaSHA384Digest, crypto.SHA384,
|
|
||||||
}, rsaSHA384Sig, false},
|
|
||||||
{"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{
|
|
||||||
rand.Reader, rsaSHA512Digest, crypto.SHA512,
|
|
||||||
}, rsaSHA512Sig, false},
|
|
||||||
{"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
|
||||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthAuto,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
},
|
|
||||||
}, rsaPSSSHA256Sig, false},
|
|
||||||
{"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{
|
|
||||||
rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
|
||||||
Hash: crypto.SHA384,
|
|
||||||
},
|
|
||||||
}, rsaPSSSHA384Sig, false},
|
|
||||||
{"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{
|
|
||||||
rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{
|
|
||||||
SaltLength: 64,
|
|
||||||
Hash: crypto.SHA512,
|
|
||||||
},
|
|
||||||
}, rsaPSSSHA512Sig, false},
|
|
||||||
{"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
|
||||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
|
||||||
}, nil, true},
|
|
||||||
{"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.SHA256,
|
|
||||||
}, nil, true},
|
|
||||||
{"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.SHA256,
|
|
||||||
}, nil, true},
|
|
||||||
{"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
|
||||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
|
||||||
SaltLength: 64,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
},
|
|
||||||
}, nil, true},
|
|
||||||
{"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
|
||||||
rand.Reader, rsaSHA256Digest, crypto.SHA1,
|
|
||||||
}, nil, true},
|
|
||||||
{"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.MD5,
|
|
||||||
}, nil, true},
|
|
||||||
{"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{
|
|
||||||
rand.Reader, []byte("message"), crypto.Hash(0),
|
|
||||||
}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
client: tt.fields.client,
|
|
||||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
|
||||||
name: tt.fields.name,
|
|
||||||
version: tt.fields.version,
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_Sign_signWithRetry(t *testing.T) {
|
|
||||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
|
||||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
h := opts.HashFunc().New()
|
|
||||||
h.Write([]byte("random-data"))
|
|
||||||
sum := h.Sum(nil)
|
|
||||||
|
|
||||||
var sig, resultSig []byte
|
|
||||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
|
||||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
curveBits := priv.Params().BitSize
|
|
||||||
keyBytes := curveBits / 8
|
|
||||||
if curveBits%8 > 0 {
|
|
||||||
keyBytes++
|
|
||||||
}
|
|
||||||
rBytes := r.Bytes()
|
|
||||||
rBytesPadded := make([]byte, keyBytes)
|
|
||||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
|
||||||
|
|
||||||
sBytes := s.Bytes()
|
|
||||||
sBytesPadded := make([]byte, keyBytes)
|
|
||||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
|
||||||
// nolint:gocritic
|
|
||||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
|
||||||
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
|
||||||
b.AddASN1BigInt(r)
|
|
||||||
b.AddASN1BigInt(s)
|
|
||||||
})
|
|
||||||
sig, err = b.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
resultSig = sig
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
|
||||||
}
|
|
||||||
|
|
||||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
|
||||||
okResult := keyvault.KeyOperationResult{
|
|
||||||
Result: &p256ResultSig,
|
|
||||||
}
|
|
||||||
failResult := keyvault.KeyOperationResult{}
|
|
||||||
retryError := autorest.DetailedError{
|
|
||||||
Original: &azure.RequestError{
|
|
||||||
ServiceError: &azure.ServiceError{
|
|
||||||
InnerError: map[string]interface{}{
|
|
||||||
"code": "KeyNotYetValid",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
client := mockClient(t)
|
|
||||||
expects := []struct {
|
|
||||||
name string
|
|
||||||
keyVersion string
|
|
||||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
|
||||||
digest []byte
|
|
||||||
result keyvault.KeyOperationResult
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
|
|
||||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
|
||||||
}
|
|
||||||
for _, e := range expects {
|
|
||||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
|
||||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
|
||||||
Algorithm: e.alg,
|
|
||||||
Value: &value,
|
|
||||||
}).Return(e.result, e.err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyVaultClient
|
|
||||||
vaultBaseURL string
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
rand io.Reader
|
|
||||||
digest []byte
|
|
||||||
opts crypto.SignerOpts
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.SHA256,
|
|
||||||
}, p256Sig, false},
|
|
||||||
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
|
|
||||||
rand.Reader, p256Digest, crypto.SHA256,
|
|
||||||
}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
client: tt.fields.client,
|
|
||||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
|
||||||
name: tt.fields.name,
|
|
||||||
version: tt.fields.version,
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,98 +0,0 @@
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"encoding/json"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultContext returns the default context used in requests to azure.
|
|
||||||
func defaultContext() (context.Context, context.CancelFunc) {
|
|
||||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getKeyName returns the uri of the key vault key.
|
|
||||||
func getKeyName(vault, name string, bundle keyvault.KeyBundle) string {
|
|
||||||
if bundle.Key != nil && bundle.Key.Kid != nil {
|
|
||||||
sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1)
|
|
||||||
if len(sm) == 1 && len(sm[0]) == 4 {
|
|
||||||
m := sm[0]
|
|
||||||
u := uri.New(Scheme, url.Values{
|
|
||||||
"vault": []string{m[1]},
|
|
||||||
"name": []string{m[2]},
|
|
||||||
})
|
|
||||||
u.RawQuery = url.Values{"version": []string{m[3]}}.Encode()
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Fallback to URI without id.
|
|
||||||
return uri.New(Scheme, url.Values{
|
|
||||||
"vault": []string{vault},
|
|
||||||
"name": []string{name},
|
|
||||||
}).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseKeyName returns the key vault, name and version from URIs like:
|
|
||||||
//
|
|
||||||
// - azurekms:vault=key-vault;name=key-name
|
|
||||||
// - azurekms:vault=key-vault;name=key-name?version=key-id
|
|
||||||
// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true
|
|
||||||
//
|
|
||||||
// The key-id defines the version of the key, if it is not passed the latest
|
|
||||||
// version will be used.
|
|
||||||
//
|
|
||||||
// HSM can also be passed to define the protection level if this is not given in
|
|
||||||
// CreateQuery.
|
|
||||||
func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) {
|
|
||||||
var u *uri.URI
|
|
||||||
|
|
||||||
u, err = uri.ParseWithScheme(Scheme, rawURI)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if name = u.Get("name"); name == "" {
|
|
||||||
err = errors.Errorf("key uri %s is not valid: name is missing", rawURI)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if vault = u.Get("vault"); vault == "" {
|
|
||||||
if defaults.Vault == "" {
|
|
||||||
name = ""
|
|
||||||
err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
vault = defaults.Vault
|
|
||||||
}
|
|
||||||
if u.Get("hsm") == "" {
|
|
||||||
hsm = (defaults.ProtectionLevel == apiv1.HSM)
|
|
||||||
} else {
|
|
||||||
hsm = u.GetBool("hsm")
|
|
||||||
}
|
|
||||||
|
|
||||||
version = u.Get("version")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func vaultBaseURL(vault string) string {
|
|
||||||
return "https://" + vault + ".vault.azure.net/"
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) {
|
|
||||||
b, err := json.Marshal(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error marshaling key")
|
|
||||||
}
|
|
||||||
var jwk jose.JSONWebKey
|
|
||||||
if err := jwk.UnmarshalJSON(b); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error unmarshaling key")
|
|
||||||
}
|
|
||||||
return jwk.Key, nil
|
|
||||||
}
|
|
|
@ -1,96 +0,0 @@
|
||||||
package azurekms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_getKeyName(t *testing.T) {
|
|
||||||
getBundle := func(kid string) keyvault.KeyBundle {
|
|
||||||
return keyvault.KeyBundle{
|
|
||||||
Key: &keyvault.JSONWebKey{
|
|
||||||
Kid: &kid,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
vault string
|
|
||||||
name string
|
|
||||||
bundle keyvault.KeyBundle
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"},
|
|
||||||
{"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
|
||||||
{"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
|
||||||
{"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"},
|
|
||||||
{"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"},
|
|
||||||
{"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want {
|
|
||||||
t.Errorf("getKeyName() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_parseKeyName(t *testing.T) {
|
|
||||||
var noOptions DefaultOptions
|
|
||||||
type args struct {
|
|
||||||
rawURI string
|
|
||||||
defaults DefaultOptions
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantVault string
|
|
||||||
wantName string
|
|
||||||
wantVersion string
|
|
||||||
wantHsm bool
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
|
||||||
{"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
|
||||||
{"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false},
|
|
||||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false},
|
|
||||||
{"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false},
|
|
||||||
{"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false},
|
|
||||||
{"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false},
|
|
||||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true},
|
|
||||||
{"fail empty", args{"", noOptions}, "", "", "", false, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotVault != tt.wantVault {
|
|
||||||
t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault)
|
|
||||||
}
|
|
||||||
if gotName != tt.wantName {
|
|
||||||
t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName)
|
|
||||||
}
|
|
||||||
if gotVersion != tt.wantVersion {
|
|
||||||
t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion)
|
|
||||||
}
|
|
||||||
if gotHsm != tt.wantHsm {
|
|
||||||
t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,348 +0,0 @@
|
||||||
package cloudkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/x509"
|
|
||||||
"log"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
cloudkms "cloud.google.com/go/kms/apiv1"
|
|
||||||
gax "github.com/googleapis/gax-go/v2"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
"google.golang.org/api/option"
|
|
||||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Scheme is the scheme used in uris.
|
|
||||||
const Scheme = "cloudkms"
|
|
||||||
|
|
||||||
const pendingGenerationRetries = 10
|
|
||||||
|
|
||||||
// protectionLevelMapping maps step protection levels with cloud kms ones.
|
|
||||||
var protectionLevelMapping = map[apiv1.ProtectionLevel]kmspb.ProtectionLevel{
|
|
||||||
apiv1.UnspecifiedProtectionLevel: kmspb.ProtectionLevel_PROTECTION_LEVEL_UNSPECIFIED,
|
|
||||||
apiv1.Software: kmspb.ProtectionLevel_SOFTWARE,
|
|
||||||
apiv1.HSM: kmspb.ProtectionLevel_HSM,
|
|
||||||
}
|
|
||||||
|
|
||||||
// signatureAlgorithmMapping is a mapping between the step signature algorithm,
|
|
||||||
// and bits for RSA keys, with cloud kms one.
|
|
||||||
//
|
|
||||||
// Cloud KMS does not support SHA384WithRSA, SHA384WithRSAPSS, SHA384WithRSAPSS,
|
|
||||||
// ECDSAWithSHA512, and PureEd25519.
|
|
||||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
|
|
||||||
apiv1.UnspecifiedSignAlgorithm: kmspb.CryptoKeyVersion_CRYPTO_KEY_VERSION_ALGORITHM_UNSPECIFIED,
|
|
||||||
apiv1.SHA256WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
|
||||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
|
|
||||||
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256,
|
|
||||||
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256,
|
|
||||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
|
||||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
|
|
||||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512,
|
|
||||||
},
|
|
||||||
apiv1.SHA256WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
|
||||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
|
|
||||||
2048: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256,
|
|
||||||
3072: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256,
|
|
||||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{
|
|
||||||
0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
|
|
||||||
4096: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA256: kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256,
|
|
||||||
apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384,
|
|
||||||
}
|
|
||||||
|
|
||||||
var cryptoKeyVersionMapping = map[kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm]x509.SignatureAlgorithm{
|
|
||||||
kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: x509.ECDSAWithSHA256,
|
|
||||||
kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384: x509.ECDSAWithSHA384,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256: x509.SHA256WithRSA,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256: x509.SHA256WithRSA,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256: x509.SHA256WithRSA,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: x509.SHA512WithRSA,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256: x509.SHA256WithRSAPSS,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256: x509.SHA256WithRSAPSS,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256: x509.SHA256WithRSAPSS,
|
|
||||||
kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512: x509.SHA512WithRSAPSS,
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyManagementClient defines the methods on KeyManagementClient that this
|
|
||||||
// package will use. This interface will be used for unit testing.
|
|
||||||
type KeyManagementClient interface {
|
|
||||||
Close() error
|
|
||||||
GetPublicKey(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
|
|
||||||
AsymmetricSign(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
|
|
||||||
CreateCryptoKey(context.Context, *kmspb.CreateCryptoKeyRequest, ...gax.CallOption) (*kmspb.CryptoKey, error)
|
|
||||||
GetKeyRing(context.Context, *kmspb.GetKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
|
||||||
CreateKeyRing(context.Context, *kmspb.CreateKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
|
||||||
CreateCryptoKeyVersion(ctx context.Context, req *kmspb.CreateCryptoKeyVersionRequest, opts ...gax.CallOption) (*kmspb.CryptoKeyVersion, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
|
|
||||||
return cloudkms.NewKeyManagementClient(ctx, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloudKMS implements a KMS using Google's Cloud apiv1.
|
|
||||||
type CloudKMS struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new CloudKMS configured with a new client.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) {
|
|
||||||
var cloudOpts []option.ClientOption
|
|
||||||
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if f := u.Get("credentials-file"); f != "" {
|
|
||||||
cloudOpts = append(cloudOpts, option.WithCredentialsFile(f))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated way to set configuration parameters.
|
|
||||||
if opts.CredentialsFile != "" {
|
|
||||||
cloudOpts = append(cloudOpts, option.WithCredentialsFile(opts.CredentialsFile))
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := newKeyManagementClient(ctx, cloudOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CloudKMS{
|
|
||||||
client: client,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.CloudKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCloudKMS creates a CloudKMS with a given client.
|
|
||||||
func NewCloudKMS(client KeyManagementClient) *CloudKMS {
|
|
||||||
return &CloudKMS{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the connection of the Cloud KMS client.
|
|
||||||
func (k *CloudKMS) Close() error {
|
|
||||||
if err := k.client.Close(); err != nil {
|
|
||||||
return errors.Wrap(err, "cloudKMS Close failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner returns a new cloudkms signer configured with the given signing
|
|
||||||
// key name.
|
|
||||||
func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
if req.SigningKey == "" {
|
|
||||||
return nil, errors.New("signing key cannot be empty")
|
|
||||||
}
|
|
||||||
return NewSigner(k.client, req.SigningKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey creates in Google's Cloud KMS a new asymmetric key for signing.
|
|
||||||
func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
protectionLevel, ok := protectionLevelMapping[req.ProtectionLevel]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("cloudKMS does not support protection level '%s'", req.ProtectionLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
var signatureAlgorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
|
|
||||||
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
|
||||||
}
|
|
||||||
switch v := v.(type) {
|
|
||||||
case kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
|
|
||||||
signatureAlgorithm = v
|
|
||||||
case map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm:
|
|
||||||
if signatureAlgorithm, ok = v[req.Bits]; !ok {
|
|
||||||
return nil, errors.Errorf("cloudKMS does not support signature algorithm '%s' with '%d' bits", req.SignatureAlgorithm, req.Bits)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unexpected error: this should not happen")
|
|
||||||
}
|
|
||||||
|
|
||||||
var crytoKeyName string
|
|
||||||
|
|
||||||
// Split `projects/PROJECT_ID/locations/global/keyRings/RING_ID/cryptoKeys/KEY_ID`
|
|
||||||
// to `projects/PROJECT_ID/locations/global/keyRings/RING_ID` and `KEY_ID`.
|
|
||||||
keyRing, keyID := Parent(req.Name)
|
|
||||||
if err := k.createKeyRingIfNeeded(keyRing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create private key in CloudKMS.
|
|
||||||
response, err := k.client.CreateCryptoKey(ctx, &kmspb.CreateCryptoKeyRequest{
|
|
||||||
Parent: keyRing,
|
|
||||||
CryptoKeyId: keyID,
|
|
||||||
CryptoKey: &kmspb.CryptoKey{
|
|
||||||
Purpose: kmspb.CryptoKey_ASYMMETRIC_SIGN,
|
|
||||||
VersionTemplate: &kmspb.CryptoKeyVersionTemplate{
|
|
||||||
ProtectionLevel: protectionLevel,
|
|
||||||
Algorithm: signatureAlgorithm,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if status.Code(err) != codes.AlreadyExists {
|
|
||||||
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKey failed")
|
|
||||||
}
|
|
||||||
// Create a new version if the key already exists.
|
|
||||||
//
|
|
||||||
// Note that it will have the same purpose, protection level and
|
|
||||||
// algorithm than as previous one.
|
|
||||||
req := &kmspb.CreateCryptoKeyVersionRequest{
|
|
||||||
Parent: req.Name,
|
|
||||||
CryptoKeyVersion: &kmspb.CryptoKeyVersion{
|
|
||||||
State: kmspb.CryptoKeyVersion_ENABLED,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
response, err := k.client.CreateCryptoKeyVersion(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed")
|
|
||||||
}
|
|
||||||
crytoKeyName = response.Name
|
|
||||||
} else {
|
|
||||||
crytoKeyName = response.Name + "/cryptoKeyVersions/1"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sleep deterministically to avoid retries because of PENDING_GENERATING.
|
|
||||||
// One second is often enough.
|
|
||||||
if protectionLevel == kmspb.ProtectionLevel_HSM {
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve public key to add it to the response.
|
|
||||||
pk, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: crytoKeyName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: crytoKeyName,
|
|
||||||
PublicKey: pk,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: crytoKeyName,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *CloudKMS) createKeyRingIfNeeded(name string) error {
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := k.client.GetKeyRing(ctx, &kmspb.GetKeyRingRequest{
|
|
||||||
Name: name,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parent, child := Parent(name)
|
|
||||||
_, err = k.client.CreateKeyRing(ctx, &kmspb.CreateKeyRingRequest{
|
|
||||||
Parent: parent,
|
|
||||||
KeyRingId: child,
|
|
||||||
})
|
|
||||||
if err != nil && status.Code(err) != codes.AlreadyExists {
|
|
||||||
return errors.Wrap(err, "cloudKMS CreateKeyRing failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey gets from Google's Cloud KMS a public key by name. Key names
|
|
||||||
// follow the pattern:
|
|
||||||
//
|
|
||||||
// projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/([a-zA-Z0-9_-]{1,63})/cryptoKeys/([a-zA-Z0-9_-]{1,63})/cryptoKeyVersions/([a-zA-Z0-9_-]{1,63})
|
|
||||||
func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := k.getPublicKeyWithRetries(req.Name, pendingGenerationRetries)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
pk, err := pemutil.ParseKey([]byte(response.Pem))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return pk, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPublicKeyWithRetries retries the request if the error is
|
|
||||||
// FailedPrecondition, caused because the key is in the PENDING_GENERATION
|
|
||||||
// status.
|
|
||||||
func (k *CloudKMS) getPublicKeyWithRetries(name string, retries int) (response *kmspb.PublicKey, err error) {
|
|
||||||
workFn := func() (*kmspb.PublicKey, error) {
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
return k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
|
|
||||||
Name: name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for i := 0; i < retries; i++ {
|
|
||||||
if response, err = workFn(); err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if status.Code(err) == codes.FailedPrecondition {
|
|
||||||
log.Println("Waiting for key generation ...")
|
|
||||||
time.Sleep(time.Duration(i+1) * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultContext() (context.Context, context.CancelFunc) {
|
|
||||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parent splits a string in the format `key/value/key2/value2` in a parent and
|
|
||||||
// child, for the previous string it will return `key/value` and `value2`.
|
|
||||||
func Parent(name string) (string, string) {
|
|
||||||
a, b := parent(name)
|
|
||||||
a, _ = parent(a)
|
|
||||||
return a, b
|
|
||||||
}
|
|
||||||
|
|
||||||
func parent(name string) (string, string) {
|
|
||||||
i := strings.LastIndex(name, "/")
|
|
||||||
switch i {
|
|
||||||
case -1:
|
|
||||||
return "", name
|
|
||||||
case 0:
|
|
||||||
return "", name[i+1:]
|
|
||||||
default:
|
|
||||||
return name[:i], name[i+1:]
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,464 +0,0 @@
|
||||||
package cloudkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
gax "github.com/googleapis/gax-go/v2"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
"google.golang.org/api/option"
|
|
||||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParent(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
want1 string
|
|
||||||
}{
|
|
||||||
{"zero", args{"child"}, "", "child"},
|
|
||||||
{"one", args{"parent/child"}, "", "child"},
|
|
||||||
{"two", args{"grandparent/parent/child"}, "grandparent", "child"},
|
|
||||||
{"three", args{"great-grandparent/grandparent/parent/child"}, "great-grandparent/grandparent", "child"},
|
|
||||||
{"empty", args{""}, "", ""},
|
|
||||||
{"root", args{"/"}, "", ""},
|
|
||||||
{"child", args{"/child"}, "", "child"},
|
|
||||||
{"parent", args{"parent/"}, "", ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, got1 := Parent(tt.args.name)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("Parent() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
if got1 != tt.want1 {
|
|
||||||
t.Errorf("Parent() got1 = %v, want %v", got1, tt.want1)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
tmp := newKeyManagementClient
|
|
||||||
t.Cleanup(func() {
|
|
||||||
newKeyManagementClient = tmp
|
|
||||||
})
|
|
||||||
newKeyManagementClient = func(ctx context.Context, opts ...option.ClientOption) (KeyManagementClient, error) {
|
|
||||||
if len(opts) > 0 {
|
|
||||||
return nil, fmt.Errorf("test error")
|
|
||||||
}
|
|
||||||
return &MockClient{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *CloudKMS
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{context.Background(), apiv1.Options{}}, &CloudKMS{client: &MockClient{}}, false},
|
|
||||||
{"ok with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:"}}, &CloudKMS{client: &MockClient{}}, false},
|
|
||||||
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
|
|
||||||
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
|
|
||||||
{"fail schema", args{context.Background(), apiv1.Options{URI: "pkcs11:"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_real(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *CloudKMS
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"fail credentials", args{context.Background(), apiv1.Options{CredentialsFile: "testdata/missing"}}, nil, true},
|
|
||||||
{"fail with uri", args{context.Background(), apiv1.Options{URI: "cloudkms:credentials-file=testdata/missing"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewCloudKMS(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *CloudKMS
|
|
||||||
}{
|
|
||||||
{"ok", args{&MockClient{}}, &CloudKMS{&MockClient{}}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := NewCloudKMS(tt.args.client); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("NewCloudKMS() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloudKMS_Close(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{&MockClient{close: func() error { return nil }}}, false},
|
|
||||||
{"fail", fields{&MockClient{close: func() error { return fmt.Errorf("an error") }}}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &CloudKMS{
|
|
||||||
client: tt.fields.client,
|
|
||||||
}
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CloudKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloudKMS_CreateSigner(t *testing.T) {
|
|
||||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk, err := pemutil.ParseKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
|
|
||||||
{"fail", fields{&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return nil, fmt.Errorf("test error")
|
|
||||||
},
|
|
||||||
}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &CloudKMS{
|
|
||||||
client: tt.fields.client,
|
|
||||||
}
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if signer, ok := got.(*Signer); ok {
|
|
||||||
signer.client = &MockClient{}
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloudKMS_CreateKey(t *testing.T) {
|
|
||||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c"
|
|
||||||
testError := fmt.Errorf("an error")
|
|
||||||
alreadyExists := status.Error(codes.AlreadyExists, "already exists")
|
|
||||||
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk, err := pemutil.ParseKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var retries int
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
|
||||||
},
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
|
||||||
{"ok new key ring", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return nil, alreadyExists
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
|
||||||
},
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 3072}},
|
|
||||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
|
||||||
{"ok new key version", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return nil, alreadyExists
|
|
||||||
},
|
|
||||||
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
|
|
||||||
return &kmspb.CryptoKeyVersion{Name: keyName + "/cryptoKeyVersions/2"}, nil
|
|
||||||
},
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/2", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/2"}}, false},
|
|
||||||
{"ok with retries", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
|
||||||
},
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
if retries != 2 {
|
|
||||||
retries++
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
|
|
||||||
}
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
&apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false},
|
|
||||||
{"fail name", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
|
||||||
{"fail protection level", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.ProtectionLevel(100)}}, nil, true},
|
|
||||||
{"fail signature algorithm", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, nil, true},
|
|
||||||
{"fail number of bits", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 1024}},
|
|
||||||
nil, true},
|
|
||||||
{"fail create key ring", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
createKeyRing: func(_ context.Context, _ *kmspb.CreateKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
nil, true},
|
|
||||||
{"fail create key", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
nil, true},
|
|
||||||
{"fail create key version", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return nil, alreadyExists
|
|
||||||
},
|
|
||||||
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
nil, true},
|
|
||||||
{"fail get public key", fields{
|
|
||||||
&MockClient{
|
|
||||||
getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return &kmspb.KeyRing{}, nil
|
|
||||||
},
|
|
||||||
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return &kmspb.CryptoKey{Name: keyName}, nil
|
|
||||||
},
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}},
|
|
||||||
nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &CloudKMS{
|
|
||||||
client: tt.fields.client,
|
|
||||||
}
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CloudKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("CloudKMS.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloudKMS_GetPublicKey(t *testing.T) {
|
|
||||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
|
|
||||||
testError := fmt.Errorf("an error")
|
|
||||||
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk, err := pemutil.ParseKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var retries int
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", fields{
|
|
||||||
&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
|
|
||||||
{"ok with retries", fields{
|
|
||||||
&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
if retries != 2 {
|
|
||||||
retries++
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
|
|
||||||
}
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
|
|
||||||
{"fail name", fields{&MockClient{}}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
|
||||||
{"fail get public key", fields{
|
|
||||||
&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return nil, testError
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
|
|
||||||
{"fail parse pem", fields{
|
|
||||||
&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &CloudKMS{
|
|
||||||
client: tt.fields.client,
|
|
||||||
}
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CloudKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("CloudKMS.GetPublicKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
package cloudkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
gax "github.com/googleapis/gax-go/v2"
|
|
||||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockClient struct {
|
|
||||||
close func() error
|
|
||||||
getPublicKey func(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
|
|
||||||
asymmetricSign func(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
|
|
||||||
createCryptoKey func(context.Context, *kmspb.CreateCryptoKeyRequest, ...gax.CallOption) (*kmspb.CryptoKey, error)
|
|
||||||
getKeyRing func(context.Context, *kmspb.GetKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
|
||||||
createKeyRing func(context.Context, *kmspb.CreateKeyRingRequest, ...gax.CallOption) (*kmspb.KeyRing, error)
|
|
||||||
createCryptoKeyVersion func(context.Context, *kmspb.CreateCryptoKeyVersionRequest, ...gax.CallOption) (*kmspb.CryptoKeyVersion, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) Close() error {
|
|
||||||
return m.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) GetPublicKey(ctx context.Context, req *kmspb.GetPublicKeyRequest, opts ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return m.getPublicKey(ctx, req, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) AsymmetricSign(ctx context.Context, req *kmspb.AsymmetricSignRequest, opts ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
|
|
||||||
return m.asymmetricSign(ctx, req, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) CreateCryptoKey(ctx context.Context, req *kmspb.CreateCryptoKeyRequest, opts ...gax.CallOption) (*kmspb.CryptoKey, error) {
|
|
||||||
return m.createCryptoKey(ctx, req, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) GetKeyRing(ctx context.Context, req *kmspb.GetKeyRingRequest, opts ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return m.getKeyRing(ctx, req, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) CreateKeyRing(ctx context.Context, req *kmspb.CreateKeyRingRequest, opts ...gax.CallOption) (*kmspb.KeyRing, error) {
|
|
||||||
return m.createKeyRing(ctx, req, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) CreateCryptoKeyVersion(ctx context.Context, req *kmspb.CreateCryptoKeyVersionRequest, opts ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
|
|
||||||
return m.createCryptoKeyVersion(ctx, req, opts...)
|
|
||||||
}
|
|
|
@ -1,95 +0,0 @@
|
||||||
package cloudkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/x509"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signer implements a crypto.Signer using Google's Cloud KMS.
|
|
||||||
type Signer struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
algorithm x509.SignatureAlgorithm
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSigner creates a new crypto.Signer the given CloudKMS signing key.
|
|
||||||
func NewSigner(c KeyManagementClient, signingKey string) (*Signer, error) {
|
|
||||||
// Make sure that the key exists.
|
|
||||||
signer := &Signer{
|
|
||||||
client: c,
|
|
||||||
signingKey: signingKey,
|
|
||||||
}
|
|
||||||
if err := signer.preloadKey(signingKey); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Signer) preloadKey(signingKey string) error {
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
|
|
||||||
Name: signingKey,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cloudKMS GetPublicKey failed")
|
|
||||||
}
|
|
||||||
s.algorithm = cryptoKeyVersionMapping[response.Algorithm]
|
|
||||||
s.publicKey, err = pemutil.ParseKey([]byte(response.Pem))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public returns the public key of this signer or an error.
|
|
||||||
func (s *Signer) Public() crypto.PublicKey {
|
|
||||||
return s.publicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign signs digest with the private key stored in Google's Cloud KMS.
|
|
||||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
|
||||||
req := &kmspb.AsymmetricSignRequest{
|
|
||||||
Name: s.signingKey,
|
|
||||||
Digest: &kmspb.Digest{},
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h := opts.HashFunc(); h {
|
|
||||||
case crypto.SHA256:
|
|
||||||
req.Digest.Digest = &kmspb.Digest_Sha256{
|
|
||||||
Sha256: digest,
|
|
||||||
}
|
|
||||||
case crypto.SHA384:
|
|
||||||
req.Digest.Digest = &kmspb.Digest_Sha384{
|
|
||||||
Sha384: digest,
|
|
||||||
}
|
|
||||||
case crypto.SHA512:
|
|
||||||
req.Digest.Digest = &kmspb.Digest_Sha512{
|
|
||||||
Sha512: digest,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unsupported hash function %v", h)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
response, err := s.client.AsymmetricSign(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cloudKMS AsymmetricSign failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return response.Signature, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignatureAlgorithm returns the algorithm that must be specified in a
|
|
||||||
// certificate to sign. This is specially important to distinguish RSA and
|
|
||||||
// RSAPSS schemas.
|
|
||||||
func (s *Signer) SignatureAlgorithm() x509.SignatureAlgorithm {
|
|
||||||
return s.algorithm
|
|
||||||
}
|
|
|
@ -1,235 +0,0 @@
|
||||||
package cloudkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
gax "github.com/googleapis/gax-go/v2"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_newSigner(t *testing.T) {
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk, err := pemutil.ParseKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
c KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
|
|
||||||
},
|
|
||||||
}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey", publicKey: pk}, false},
|
|
||||||
{"fail get public key", args{&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}, "signingKey"}, nil, true},
|
|
||||||
{"fail parse pem", args{&MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
|
|
||||||
},
|
|
||||||
}, "signingKey"}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NewSigner(tt.args.c, tt.args.signingKey)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
got.client = &MockClient{}
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_signer_Public(t *testing.T) {
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk, err := pemutil.ParseKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want crypto.PublicKey
|
|
||||||
}{
|
|
||||||
{"ok", fields{&MockClient{}, "signingKey", pk}, pk},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
client: tt.fields.client,
|
|
||||||
signingKey: tt.fields.signingKey,
|
|
||||||
publicKey: tt.fields.publicKey,
|
|
||||||
}
|
|
||||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("signer.Public() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_signer_Sign(t *testing.T) {
|
|
||||||
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
|
|
||||||
okClient := &MockClient{
|
|
||||||
asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
|
|
||||||
return &kmspb.AsymmetricSignResponse{Signature: []byte("ok signature")}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
failClient := &MockClient{
|
|
||||||
asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
|
|
||||||
return nil, fmt.Errorf("an error")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
rand io.Reader
|
|
||||||
digest []byte
|
|
||||||
opts crypto.SignerOpts
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok sha256", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, []byte("ok signature"), false},
|
|
||||||
{"ok sha384", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA384}, []byte("ok signature"), false},
|
|
||||||
{"ok sha512", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA512}, []byte("ok signature"), false},
|
|
||||||
{"fail MD5", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
|
|
||||||
{"fail asymmetric sign", fields{failClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Signer{
|
|
||||||
client: tt.fields.client,
|
|
||||||
signingKey: tt.fields.signingKey,
|
|
||||||
}
|
|
||||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("signer.Sign() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSigner_SignatureAlgorithm(t *testing.T) {
|
|
||||||
pemBytes, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &MockClient{
|
|
||||||
getPublicKey: func(_ context.Context, req *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
|
|
||||||
var algorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
|
|
||||||
switch req.Name {
|
|
||||||
case "ECDSA-SHA256":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256
|
|
||||||
case "ECDSA-SHA384":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384
|
|
||||||
case "SHA256-RSA-2048":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256
|
|
||||||
case "SHA256-RSA-3072":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256
|
|
||||||
case "SHA256-RSA-4096":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256
|
|
||||||
case "SHA512-RSA-4096":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512
|
|
||||||
case "SHA256-RSAPSS-2048":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256
|
|
||||||
case "SHA256-RSAPSS-3072":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256
|
|
||||||
case "SHA256-RSAPSS-4096":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256
|
|
||||||
case "SHA512-RSAPSS-4096":
|
|
||||||
algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512
|
|
||||||
}
|
|
||||||
return &kmspb.PublicKey{
|
|
||||||
Pem: string(pemBytes),
|
|
||||||
Algorithm: algorithm,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
client KeyManagementClient
|
|
||||||
signingKey string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want x509.SignatureAlgorithm
|
|
||||||
}{
|
|
||||||
{"ECDSA-SHA256", fields{client, "ECDSA-SHA256"}, x509.ECDSAWithSHA256},
|
|
||||||
{"ECDSA-SHA384", fields{client, "ECDSA-SHA384"}, x509.ECDSAWithSHA384},
|
|
||||||
{"SHA256-RSA-2048", fields{client, "SHA256-RSA-2048"}, x509.SHA256WithRSA},
|
|
||||||
{"SHA256-RSA-3072", fields{client, "SHA256-RSA-3072"}, x509.SHA256WithRSA},
|
|
||||||
{"SHA256-RSA-4096", fields{client, "SHA256-RSA-4096"}, x509.SHA256WithRSA},
|
|
||||||
{"SHA512-RSA-4096", fields{client, "SHA512-RSA-4096"}, x509.SHA512WithRSA},
|
|
||||||
{"SHA256-RSAPSS-2048", fields{client, "SHA256-RSAPSS-2048"}, x509.SHA256WithRSAPSS},
|
|
||||||
{"SHA256-RSAPSS-3072", fields{client, "SHA256-RSAPSS-3072"}, x509.SHA256WithRSAPSS},
|
|
||||||
{"SHA256-RSAPSS-4096", fields{client, "SHA256-RSAPSS-4096"}, x509.SHA256WithRSAPSS},
|
|
||||||
{"SHA512-RSAPSS-4096", fields{client, "SHA512-RSAPSS-4096"}, x509.SHA512WithRSAPSS},
|
|
||||||
{"unknown", fields{client, "UNKNOWN"}, x509.UnknownSignatureAlgorithm},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
signer, err := NewSigner(tt.fields.client, tt.fields.signingKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
if got := signer.SignatureAlgorithm(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Signer.SignatureAlgorithm() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
4
kms/cloudkms/testdata/pub.pem
vendored
4
kms/cloudkms/testdata/pub.pem
vendored
|
@ -1,4 +0,0 @@
|
||||||
-----BEGIN PUBLIC KEY-----
|
|
||||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
|
|
||||||
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
|
|
||||||
-----END PUBLIC KEY-----
|
|
43
kms/kms.go
43
kms/kms.go
|
@ -1,43 +0,0 @@
|
||||||
package kms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
|
|
||||||
// Enable default implementation
|
|
||||||
"github.com/smallstep/certificates/kms/softkms"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeyManager is the interface implemented by all the KMS.
|
|
||||||
type KeyManager = apiv1.KeyManager
|
|
||||||
|
|
||||||
// CertificateManager is the interface implemented by the KMS that can load and
|
|
||||||
// store x509.Certificates.
|
|
||||||
type CertificateManager = apiv1.CertificateManager
|
|
||||||
|
|
||||||
// Options are the KMS options. They represent the kms object in the ca.json.
|
|
||||||
type Options = apiv1.Options
|
|
||||||
|
|
||||||
// Default is the implementation of the default KMS.
|
|
||||||
var Default = &softkms.SoftKMS{}
|
|
||||||
|
|
||||||
// New initializes a new KMS from the given type.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (KeyManager, error) {
|
|
||||||
if err := opts.Validate(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t := apiv1.Type(strings.ToLower(opts.Type))
|
|
||||||
if t == apiv1.DefaultKMS {
|
|
||||||
t = apiv1.SoftKMS
|
|
||||||
}
|
|
||||||
|
|
||||||
fn, ok := apiv1.LoadKeyManagerNewFunc(t)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("unsupported kms type '%s'", t)
|
|
||||||
}
|
|
||||||
return fn(ctx, opts)
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
package kms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/awskms"
|
|
||||||
"github.com/smallstep/certificates/kms/cloudkms"
|
|
||||||
"github.com/smallstep/certificates/kms/softkms"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
skipOnCI bool
|
|
||||||
args args
|
|
||||||
want KeyManager
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"softkms", false, args{ctx, apiv1.Options{Type: "softkms"}}, &softkms.SoftKMS{}, false},
|
|
||||||
{"default", false, args{ctx, apiv1.Options{}}, &softkms.SoftKMS{}, false},
|
|
||||||
{"awskms", false, args{ctx, apiv1.Options{Type: "awskms"}}, &awskms.KMS{}, false},
|
|
||||||
{"cloudkms", true, args{ctx, apiv1.Options{Type: "cloudkms"}}, &cloudkms.CloudKMS{}, true}, // fails because not credentials
|
|
||||||
{"pkcs11", false, args{ctx, apiv1.Options{Type: "pkcs11"}}, nil, true}, // not yet supported
|
|
||||||
{"fail validation", false, args{ctx, apiv1.Options{Type: "foobar"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.skipOnCI && os.Getenv("CI") == "true" {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(got) != reflect.TypeOf(tt.want) {
|
|
||||||
t.Errorf("New() = %T, want %T", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,83 +0,0 @@
|
||||||
//go:build cgo
|
|
||||||
// +build cgo
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func benchmarkSign(b *testing.B, signer crypto.Signer, opts crypto.SignerOpts) {
|
|
||||||
hash := opts.HashFunc()
|
|
||||||
h := hash.New()
|
|
||||||
h.Write([]byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger"))
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
signer.Sign(rand.Reader, digest, opts)
|
|
||||||
}
|
|
||||||
b.StopTimer()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSignRSA(b *testing.B) {
|
|
||||||
k := setupPKCS11(b)
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7371;object=rsa-key",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
benchmarkSign(b, signer, crypto.SHA256)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSignRSAPSS(b *testing.B) {
|
|
||||||
k := setupPKCS11(b)
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7372;object=rsa-pss-key",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
benchmarkSign(b, signer, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSignP256(b *testing.B) {
|
|
||||||
k := setupPKCS11(b)
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
benchmarkSign(b, signer, crypto.SHA256)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSignP384(b *testing.B) {
|
|
||||||
k := setupPKCS11(b)
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7374;object=ecdsa-p384-key",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
benchmarkSign(b, signer, crypto.SHA384)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSignP521(b *testing.B) {
|
|
||||||
k := setupPKCS11(b)
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7375;object=ecdsa-p521-key",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
}
|
|
||||||
benchmarkSign(b, signer, crypto.SHA512)
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
//go:build opensc
|
|
||||||
// +build opensc
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
)
|
|
||||||
|
|
||||||
var softHSM2Once sync.Once
|
|
||||||
|
|
||||||
// mustPKCS11 configures a *PKCS11 KMS to be used with OpenSC, using for example
|
|
||||||
// a Nitrokey HSM. To initialize these tests we should run:
|
|
||||||
//
|
|
||||||
// sc-hsm-tool --initialize --so-pin 3537363231383830 --pin 123456
|
|
||||||
//
|
|
||||||
// Or:
|
|
||||||
//
|
|
||||||
// pkcs11-tool --module /usr/local/lib/opensc-pkcs11.so \
|
|
||||||
// --init-token --init-pin \
|
|
||||||
// --so-pin=3537363231383830 --new-pin=123456 --pin=123456 \
|
|
||||||
// --label="pkcs11-test"
|
|
||||||
func mustPKCS11(t TBTesting) *PKCS11 {
|
|
||||||
t.Helper()
|
|
||||||
testModule = "OpenSC"
|
|
||||||
if runtime.GOARCH != "amd64" {
|
|
||||||
t.Fatalf("opensc test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
|
|
||||||
}
|
|
||||||
|
|
||||||
var path string
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
path = "/usr/local/lib/opensc-pkcs11.so"
|
|
||||||
case "linux":
|
|
||||||
path = "/usr/local/lib/opensc-pkcs11.so"
|
|
||||||
default:
|
|
||||||
t.Skipf("opensc test skipped on %s", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var zero int
|
|
||||||
p11, err := crypto11.Configure(&crypto11.Config{
|
|
||||||
Path: path,
|
|
||||||
SlotNumber: &zero,
|
|
||||||
Pin: "123456",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to configure opensc on %s: %v", runtime.GOOS, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
k := &PKCS11{
|
|
||||||
p11: p11,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup
|
|
||||||
softHSM2Once.Do(func() {
|
|
||||||
teardown(t, k)
|
|
||||||
setup(t, k)
|
|
||||||
})
|
|
||||||
|
|
||||||
return k
|
|
||||||
}
|
|
|
@ -1,210 +0,0 @@
|
||||||
//go:build cgo && !softhsm2 && !yubihsm2 && !opensc
|
|
||||||
// +build cgo,!softhsm2,!yubihsm2,!opensc
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
func mustPKCS11(t TBTesting) *PKCS11 {
|
|
||||||
t.Helper()
|
|
||||||
testModule = "Golang crypto"
|
|
||||||
k := &PKCS11{
|
|
||||||
p11: &stubPKCS11{
|
|
||||||
signerIndex: make(map[keyType]int),
|
|
||||||
certIndex: make(map[keyType]int),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for i := range testCerts {
|
|
||||||
testCerts[i].Certificates = nil
|
|
||||||
}
|
|
||||||
teardown(t, k)
|
|
||||||
setup(t, k)
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyType struct {
|
|
||||||
id string
|
|
||||||
label string
|
|
||||||
serial string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newKey(id, label []byte, serial *big.Int) keyType {
|
|
||||||
var serialString string
|
|
||||||
if serial != nil {
|
|
||||||
serialString = serial.String()
|
|
||||||
}
|
|
||||||
return keyType{
|
|
||||||
id: string(id),
|
|
||||||
label: string(label),
|
|
||||||
serial: serialString,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubPKCS11 struct {
|
|
||||||
signers []crypto11.Signer
|
|
||||||
certs []*x509.Certificate
|
|
||||||
signerIndex map[keyType]int
|
|
||||||
certIndex map[keyType]int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) FindKeyPair(id, label []byte) (crypto11.Signer, error) {
|
|
||||||
if id == nil && label == nil {
|
|
||||||
return nil, errors.New("id and label cannot both be nil")
|
|
||||||
}
|
|
||||||
if i, ok := s.signerIndex[newKey(id, label, nil)]; ok {
|
|
||||||
return s.signers[i], nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) FindCertificate(id, label []byte, serial *big.Int) (*x509.Certificate, error) {
|
|
||||||
if id == nil && label == nil && serial == nil {
|
|
||||||
return nil, errors.New("id, label and serial cannot both be nil")
|
|
||||||
}
|
|
||||||
if i, ok := s.certIndex[newKey(id, label, serial)]; ok {
|
|
||||||
return s.certs[i], nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) ImportCertificateWithAttributes(template crypto11.AttributeSet, cert *x509.Certificate) error {
|
|
||||||
var id, label []byte
|
|
||||||
if v := template[crypto11.CkaId]; v != nil {
|
|
||||||
id = v.Value
|
|
||||||
}
|
|
||||||
if v := template[crypto11.CkaLabel]; v != nil {
|
|
||||||
label = v.Value
|
|
||||||
}
|
|
||||||
return s.ImportCertificateWithLabel(id, label, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) ImportCertificateWithLabel(id, label []byte, cert *x509.Certificate) error {
|
|
||||||
switch {
|
|
||||||
case id == nil:
|
|
||||||
return errors.New("id cannot both be nil")
|
|
||||||
case label == nil:
|
|
||||||
return errors.New("label cannot both be nil")
|
|
||||||
case cert == nil:
|
|
||||||
return errors.New("certificate cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
i := len(s.certs)
|
|
||||||
s.certs = append(s.certs, cert)
|
|
||||||
s.certIndex[newKey(id, label, cert.SerialNumber)] = i
|
|
||||||
s.certIndex[newKey(id, nil, nil)] = i
|
|
||||||
s.certIndex[newKey(nil, label, nil)] = i
|
|
||||||
s.certIndex[newKey(nil, nil, cert.SerialNumber)] = i
|
|
||||||
s.certIndex[newKey(id, label, nil)] = i
|
|
||||||
s.certIndex[newKey(id, nil, cert.SerialNumber)] = i
|
|
||||||
s.certIndex[newKey(nil, label, cert.SerialNumber)] = i
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) DeleteCertificate(id, label []byte, serial *big.Int) error {
|
|
||||||
if id == nil && label == nil && serial == nil {
|
|
||||||
return errors.New("id, label and serial cannot both be nil")
|
|
||||||
}
|
|
||||||
if i, ok := s.certIndex[newKey(id, label, serial)]; ok {
|
|
||||||
s.certs[i] = nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error) {
|
|
||||||
var id, label []byte
|
|
||||||
if v := public[crypto11.CkaId]; v != nil {
|
|
||||||
id = v.Value
|
|
||||||
}
|
|
||||||
if v := public[crypto11.CkaLabel]; v != nil {
|
|
||||||
label = v.Value
|
|
||||||
}
|
|
||||||
return s.GenerateRSAKeyPairWithLabel(id, label, bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) GenerateRSAKeyPairWithLabel(id, label []byte, bits int) (crypto11.SignerDecrypter, error) {
|
|
||||||
if id == nil && label == nil {
|
|
||||||
return nil, errors.New("id and label cannot both be nil")
|
|
||||||
}
|
|
||||||
p, err := rsa.GenerateKey(rand.Reader, bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
k := &privateKey{
|
|
||||||
Signer: p,
|
|
||||||
index: len(s.signers),
|
|
||||||
stub: s,
|
|
||||||
}
|
|
||||||
s.signers = append(s.signers, k)
|
|
||||||
s.signerIndex[newKey(id, label, nil)] = k.index
|
|
||||||
s.signerIndex[newKey(id, nil, nil)] = k.index
|
|
||||||
s.signerIndex[newKey(nil, label, nil)] = k.index
|
|
||||||
return k, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error) {
|
|
||||||
var id, label []byte
|
|
||||||
if v := public[crypto11.CkaId]; v != nil {
|
|
||||||
id = v.Value
|
|
||||||
}
|
|
||||||
if v := public[crypto11.CkaLabel]; v != nil {
|
|
||||||
label = v.Value
|
|
||||||
}
|
|
||||||
return s.GenerateECDSAKeyPairWithLabel(id, label, curve)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) GenerateECDSAKeyPairWithLabel(id, label []byte, curve elliptic.Curve) (crypto11.Signer, error) {
|
|
||||||
if id == nil && label == nil {
|
|
||||||
return nil, errors.New("id and label cannot both be nil")
|
|
||||||
}
|
|
||||||
p, err := ecdsa.GenerateKey(curve, rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
k := &privateKey{
|
|
||||||
Signer: p,
|
|
||||||
index: len(s.signers),
|
|
||||||
stub: s,
|
|
||||||
}
|
|
||||||
s.signers = append(s.signers, k)
|
|
||||||
s.signerIndex[newKey(id, label, nil)] = k.index
|
|
||||||
s.signerIndex[newKey(id, nil, nil)] = k.index
|
|
||||||
s.signerIndex[newKey(nil, label, nil)] = k.index
|
|
||||||
return k, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPKCS11) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type privateKey struct {
|
|
||||||
crypto.Signer
|
|
||||||
index int
|
|
||||||
stub *stubPKCS11
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *privateKey) Delete() error {
|
|
||||||
s.stub.signers[s.index] = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *privateKey) Decrypt(rnd io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
|
||||||
k, ok := s.Signer.(*rsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("key is not an rsa key")
|
|
||||||
}
|
|
||||||
return k.Decrypt(rnd, msg, opts)
|
|
||||||
}
|
|
|
@ -1,399 +0,0 @@
|
||||||
//go:build cgo
|
|
||||||
// +build cgo
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Scheme is the scheme used in uris.
|
|
||||||
const Scheme = "pkcs11"
|
|
||||||
|
|
||||||
// DefaultRSASize is the number of bits of a new RSA key if no size has been
|
|
||||||
// specified.
|
|
||||||
const DefaultRSASize = 3072
|
|
||||||
|
|
||||||
// P11 defines the methods on crypto11.Context that this package will use. This
|
|
||||||
// interface will be used for unit testing.
|
|
||||||
type P11 interface {
|
|
||||||
FindKeyPair(id, label []byte) (crypto11.Signer, error)
|
|
||||||
FindCertificate(id, label []byte, serial *big.Int) (*x509.Certificate, error)
|
|
||||||
ImportCertificateWithAttributes(template crypto11.AttributeSet, certificate *x509.Certificate) error
|
|
||||||
DeleteCertificate(id, label []byte, serial *big.Int) error
|
|
||||||
GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error)
|
|
||||||
GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
var p11Configure = func(config *crypto11.Config) (P11, error) {
|
|
||||||
return crypto11.Configure(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCS11 is the implementation of a KMS using the PKCS #11 standard.
|
|
||||||
type PKCS11 struct {
|
|
||||||
p11 P11
|
|
||||||
closed sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new PKCS11 KMS.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) {
|
|
||||||
var config crypto11.Config
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Pin = u.Pin()
|
|
||||||
config.Path = u.Get("module-path")
|
|
||||||
config.TokenLabel = u.Get("token")
|
|
||||||
config.TokenSerial = u.Get("serial")
|
|
||||||
if v := u.Get("slot-id"); v != "" {
|
|
||||||
n, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "kms uri 'slot-id' is not valid")
|
|
||||||
}
|
|
||||||
config.SlotNumber = &n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if config.Pin == "" && opts.Pin != "" {
|
|
||||||
config.Pin = opts.Pin
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case config.Path == "":
|
|
||||||
return nil, errors.New("kms uri 'module-path' are required")
|
|
||||||
case config.TokenLabel == "" && config.TokenSerial == "" && config.SlotNumber == nil:
|
|
||||||
return nil, errors.New("kms uri 'token', 'serial' or 'slot-id' are required")
|
|
||||||
case config.Pin == "":
|
|
||||||
return nil, errors.New("kms 'pin' cannot be empty")
|
|
||||||
case config.TokenLabel != "" && config.TokenSerial != "":
|
|
||||||
return nil, errors.New("kms uri 'token' and 'serial' are mutually exclusive")
|
|
||||||
case config.TokenLabel != "" && config.SlotNumber != nil:
|
|
||||||
return nil, errors.New("kms uri 'token' and 'slot-id' are mutually exclusive")
|
|
||||||
case config.TokenSerial != "" && config.SlotNumber != nil:
|
|
||||||
return nil, errors.New("kms uri 'serial' and 'slot-id' are mutually exclusive")
|
|
||||||
}
|
|
||||||
|
|
||||||
p11, err := p11Configure(&config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error initializing PKCS#11")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PKCS11{
|
|
||||||
p11: p11,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.PKCS11, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey returns the public key ....
|
|
||||||
func (k *PKCS11) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := findSigner(k.p11, req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getPublicKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer.Public(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey generates a new key in the PKCS#11 module and returns the public key.
|
|
||||||
func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
switch {
|
|
||||||
case req.Name == "":
|
|
||||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
|
||||||
case req.Bits < 0:
|
|
||||||
return nil, errors.New("createKeyRequest 'bits' cannot be negative")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := generateKey(k.p11, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "createKey failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: req.Name,
|
|
||||||
PublicKey: signer.Public(),
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: req.Name,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner creates a signer using a key present in the PKCS#11 module.
|
|
||||||
func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
if req.SigningKey == "" {
|
|
||||||
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := findSigner(k.p11, req.SigningKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "createSigner failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDecrypter creates a decrypter using a key present in the PKCS#11
|
|
||||||
// module.
|
|
||||||
func (k *PKCS11) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
|
|
||||||
if req.DecryptionKey == "" {
|
|
||||||
return nil, errors.New("createDecrypterRequest 'decryptionKey' cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := findSigner(k.p11, req.DecryptionKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "createDecrypterRequest failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only RSA keys will implement the Decrypter interface.
|
|
||||||
if _, ok := signer.Public().(*rsa.PublicKey); ok {
|
|
||||||
if dec, ok := signer.(crypto.Decrypter); ok {
|
|
||||||
return dec, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, errors.New("createDecrypterRequest failed: signer does not implement crypto.Decrypter")
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadCertificate implements kms.CertificateManager and loads a certificate
|
|
||||||
// from the YubiKey.
|
|
||||||
func (k *PKCS11) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errors.New("loadCertificateRequest 'name' cannot be nil")
|
|
||||||
}
|
|
||||||
cert, err := findCertificate(k.p11, req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "loadCertificate failed")
|
|
||||||
}
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreCertificate implements kms.CertificateManager and stores a certificate
|
|
||||||
// in the YubiKey.
|
|
||||||
func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
|
|
||||||
switch {
|
|
||||||
case req.Name == "":
|
|
||||||
return errors.New("storeCertificateRequest 'name' cannot be empty")
|
|
||||||
case req.Certificate == nil:
|
|
||||||
return errors.New("storeCertificateRequest 'Certificate' cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
id, object, err := parseObject(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "storeCertificate failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce the use of both id and labels. This is not strictly necessary in
|
|
||||||
// PKCS #11, but it's a good practice.
|
|
||||||
if len(id) == 0 || len(object) == 0 {
|
|
||||||
return errors.Errorf("key with uri %s is not valid, id and object are required", req.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := k.p11.FindCertificate(id, object, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "storeCertificate failed")
|
|
||||||
}
|
|
||||||
if cert != nil {
|
|
||||||
return errors.Wrap(apiv1.ErrAlreadyExists{
|
|
||||||
Message: req.Name + " already exists",
|
|
||||||
}, "storeCertificate failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Import certificate with the necessary attributes.
|
|
||||||
template, err := crypto11.NewAttributeSetWithIDAndLabel(id, object)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "storeCertificate failed")
|
|
||||||
}
|
|
||||||
if req.Extractable {
|
|
||||||
template.Set(crypto11.CkaExtractable, true)
|
|
||||||
}
|
|
||||||
if err := k.p11.ImportCertificateWithAttributes(template, req.Certificate); err != nil {
|
|
||||||
return errors.Wrap(err, "storeCertificate failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteKey is a utility function to delete a key given an uri.
|
|
||||||
func (k *PKCS11) DeleteKey(u string) error {
|
|
||||||
id, object, err := parseObject(u)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "deleteKey failed")
|
|
||||||
}
|
|
||||||
signer, err := k.p11.FindKeyPair(id, object)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "deleteKey failed")
|
|
||||||
}
|
|
||||||
if signer == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := signer.Delete(); err != nil {
|
|
||||||
return errors.Wrap(err, "deleteKey failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteCertificate is a utility function to delete a certificate given an uri.
|
|
||||||
func (k *PKCS11) DeleteCertificate(u string) error {
|
|
||||||
id, object, err := parseObject(u)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "deleteCertificate failed")
|
|
||||||
}
|
|
||||||
if err := k.p11.DeleteCertificate(id, object, nil); err != nil {
|
|
||||||
return errors.Wrap(err, "deleteCertificate failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close releases the connection to the PKCS#11 module.
|
|
||||||
func (k *PKCS11) Close() (err error) {
|
|
||||||
k.closed.Do(func() {
|
|
||||||
err = errors.Wrap(k.p11.Close(), "error closing pkcs#11 context")
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func toByte(s string) []byte {
|
|
||||||
if s == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return []byte(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseObject(rawuri string) ([]byte, []byte, error) {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
id := u.GetEncoded("id")
|
|
||||||
object := u.Get("object")
|
|
||||||
if len(id) == 0 && object == "" {
|
|
||||||
return nil, nil, errors.Errorf("key with uri %s is not valid, id or object are required", rawuri)
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, toByte(object), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateKey(ctx P11, req *apiv1.CreateKeyRequest) (crypto11.Signer, error) {
|
|
||||||
id, object, err := parseObject(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ctx.FindKeyPair(id, object)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if signer != nil {
|
|
||||||
return nil, apiv1.ErrAlreadyExists{
|
|
||||||
Message: req.Name + " already exists",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce the use of both id and labels. This is not strictly necessary in
|
|
||||||
// PKCS #11, but it's a good practice.
|
|
||||||
if len(id) == 0 || len(object) == 0 {
|
|
||||||
return nil, errors.Errorf("key with uri %s is not valid, id and object are required", req.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create template for public and private keys
|
|
||||||
public, err := crypto11.NewAttributeSetWithIDAndLabel(id, object)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
private := public.Copy()
|
|
||||||
if req.Extractable {
|
|
||||||
private.Set(crypto11.CkaExtractable, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
bits := req.Bits
|
|
||||||
if bits == 0 {
|
|
||||||
bits = DefaultRSASize
|
|
||||||
}
|
|
||||||
|
|
||||||
switch req.SignatureAlgorithm {
|
|
||||||
case apiv1.UnspecifiedSignAlgorithm:
|
|
||||||
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P256())
|
|
||||||
case apiv1.SHA256WithRSA, apiv1.SHA384WithRSA, apiv1.SHA512WithRSA:
|
|
||||||
return ctx.GenerateRSAKeyPairWithAttributes(public, private, bits)
|
|
||||||
case apiv1.SHA256WithRSAPSS, apiv1.SHA384WithRSAPSS, apiv1.SHA512WithRSAPSS:
|
|
||||||
return ctx.GenerateRSAKeyPairWithAttributes(public, private, bits)
|
|
||||||
case apiv1.ECDSAWithSHA256:
|
|
||||||
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P256())
|
|
||||||
case apiv1.ECDSAWithSHA384:
|
|
||||||
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P384())
|
|
||||||
case apiv1.ECDSAWithSHA512:
|
|
||||||
return ctx.GenerateECDSAKeyPairWithAttributes(public, private, elliptic.P521())
|
|
||||||
case apiv1.PureEd25519:
|
|
||||||
return nil, fmt.Errorf("signature algorithm %s is not supported", req.SignatureAlgorithm)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("signature algorithm %s is not supported", req.SignatureAlgorithm)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func findSigner(ctx P11, rawuri string) (crypto11.Signer, error) {
|
|
||||||
id, object, err := parseObject(rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
signer, err := ctx.FindKeyPair(id, object)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error finding key with uri %s", rawuri)
|
|
||||||
}
|
|
||||||
if signer == nil {
|
|
||||||
return nil, errors.Errorf("key with uri %s not found", rawuri)
|
|
||||||
}
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCertificate(ctx P11, rawuri string) (*x509.Certificate, error) {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
id, object, serial := u.GetEncoded("id"), u.Get("object"), u.Get("serial")
|
|
||||||
if len(id) == 0 && object == "" && serial == "" {
|
|
||||||
return nil, errors.Errorf("key with uri %s is not valid, id, object or serial are required", rawuri)
|
|
||||||
}
|
|
||||||
|
|
||||||
var serialNumber *big.Int
|
|
||||||
if serial != "" {
|
|
||||||
b, err := hex.DecodeString(serial)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Errorf("key with uri %s is not valid, failed to decode serial", rawuri)
|
|
||||||
}
|
|
||||||
serialNumber = new(big.Int).SetBytes(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := ctx.FindCertificate(id, toByte(object), serialNumber)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error finding certificate with uri %s", rawuri)
|
|
||||||
}
|
|
||||||
if cert == nil {
|
|
||||||
return nil, errors.Errorf("certificate with uri %s not found", rawuri)
|
|
||||||
}
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
|
@ -1,58 +0,0 @@
|
||||||
//go:build !cgo
|
|
||||||
// +build !cgo
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errUnsupported error
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
name := filepath.Base(os.Args[0])
|
|
||||||
errUnsupported = errors.Errorf("unsupported kms type 'pkcs11': %s is compiled without cgo support", name)
|
|
||||||
|
|
||||||
apiv1.Register(apiv1.PKCS11, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return nil, errUnsupported
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCS11 is the implementation of a KMS using the PKCS #11 standard.
|
|
||||||
type PKCS11 struct{}
|
|
||||||
|
|
||||||
// New implements the kms.KeyManager interface and without CGO will always
|
|
||||||
// return an error.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) {
|
|
||||||
return nil, errUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey implements the kms.KeyManager interface and without CGO will always
|
|
||||||
// return an error.
|
|
||||||
func (*PKCS11) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
return nil, errUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey implements the kms.KeyManager interface and without CGO will always
|
|
||||||
// return an error.
|
|
||||||
func (*PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
return nil, errUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner implements the kms.KeyManager interface and without CGO will always
|
|
||||||
// return an error.
|
|
||||||
func (*PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
return nil, errUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close implements the kms.KeyManager interface and without CGO will always
|
|
||||||
// return an error.
|
|
||||||
func (*PKCS11) Close() error {
|
|
||||||
return errUnsupported
|
|
||||||
}
|
|
|
@ -1,836 +0,0 @@
|
||||||
//go:build cgo
|
|
||||||
// +build cgo
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"math/big"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
tmp := p11Configure
|
|
||||||
t.Cleanup(func() {
|
|
||||||
p11Configure = tmp
|
|
||||||
})
|
|
||||||
|
|
||||||
k := mustPKCS11(t)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
k.Close()
|
|
||||||
})
|
|
||||||
p11Configure = func(config *crypto11.Config) (P11, error) {
|
|
||||||
if strings.Contains(config.Path, "fail") {
|
|
||||||
return nil, errors.New("an error")
|
|
||||||
}
|
|
||||||
return k.p11, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *PKCS11
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password",
|
|
||||||
}}, k, false},
|
|
||||||
{"ok with serial", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789?pin-value=password",
|
|
||||||
}}, k, false},
|
|
||||||
{"ok with slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=0?pin-value=password",
|
|
||||||
}}, k, false},
|
|
||||||
{"ok with pin", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, k, false},
|
|
||||||
{"fail missing module", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:token=pkcs11-test",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail missing pin", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail missing token/serial/slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail token+serial+slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789;slot-id=0",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail token+serial", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail token+slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;slot-id=0",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail serial+slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789;slot-id=0",
|
|
||||||
Pin: "passowrd",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail slot-id", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=x?pin-value=password",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail scheme", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "foo:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail configure", args{context.Background(), apiv1.Options{
|
|
||||||
Type: "pkcs11",
|
|
||||||
URI: "pkcs11:module-path=/usr/local/lib/fail.so;token=pkcs11-test?pin-value=password",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_GetPublicKey(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"RSA", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:id=7371;object=rsa-key",
|
|
||||||
}}, &rsa.PublicKey{}, false},
|
|
||||||
{"RSA by id", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:id=7371",
|
|
||||||
}}, &rsa.PublicKey{}, false},
|
|
||||||
{"RSA by label", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:object=rsa-key",
|
|
||||||
}}, &rsa.PublicKey{}, false},
|
|
||||||
{"ECDSA", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:id=7373;object=ecdsa-p256-key",
|
|
||||||
}}, &ecdsa.PublicKey{}, false},
|
|
||||||
{"ECDSA by id", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:id=7373",
|
|
||||||
}}, &ecdsa.PublicKey{}, false},
|
|
||||||
{"ECDSA by label", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:object=ecdsa-p256-key",
|
|
||||||
}}, &ecdsa.PublicKey{}, false},
|
|
||||||
{"fail name", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail uri", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "https:id=9999;object=https",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail missing", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:id=9999;object=rsa-key",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail FindKeyPair", args{&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: "pkcs11:foo=bar",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(got) != reflect.TypeOf(tt.want) {
|
|
||||||
t.Errorf("PKCS11.GetPublicKey() = %T, want %T", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_CreateKey(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
|
|
||||||
// Make sure to delete the created key
|
|
||||||
k.DeleteKey(testObject)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"default", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &ecdsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"default extractable", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
Extractable: true,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &ecdsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA256WithRSA", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA384WithRSA", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA384WithRSA,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA512WithRSA", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA512WithRSA,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA256WithRSAPSS", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA384WithRSAPSS", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA SHA512WithRSAPSS", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA 2048", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
Bits: 2048,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"RSA 4096", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
||||||
Bits: 4096,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &rsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ECDSA P256", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &ecdsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ECDSA P384", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &ecdsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"ECDSA P521", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
|
|
||||||
}}, &apiv1.CreateKeyResponse{
|
|
||||||
Name: testObject,
|
|
||||||
PublicKey: &ecdsa.PublicKey{},
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: testObject,
|
|
||||||
},
|
|
||||||
}, false},
|
|
||||||
{"fail name", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail no id", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:object=create-key",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail no object", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:id=9999",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail schema", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs12:id=9999;object=create-key",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail bits", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:id=9999;object=create-key",
|
|
||||||
Bits: -1,
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail ed25519", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:id=9999;object=create-key",
|
|
||||||
SignatureAlgorithm: apiv1.PureEd25519,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail unknown", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:id=9999;object=create-key",
|
|
||||||
SignatureAlgorithm: apiv1.SignatureAlgorithm(100),
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail FindKeyPair", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:foo=bar",
|
|
||||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail already exists", args{&apiv1.CreateKeyRequest{
|
|
||||||
Name: "pkcs11:id=7373;object=ecdsa-p256-key",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
got.PublicKey = tt.want.PublicKey
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("PKCS11.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
if err := k.DeleteKey(got.Name); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteKey() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_CreateSigner(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")
|
|
||||||
|
|
||||||
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
|
|
||||||
// public key, pub. Its return value records whether the signature is valid.
|
|
||||||
verifyASN1 := func(pub *ecdsa.PublicKey, hash, sig []byte) bool {
|
|
||||||
var (
|
|
||||||
r, s = &big.Int{}, &big.Int{}
|
|
||||||
inner cryptobyte.String
|
|
||||||
)
|
|
||||||
input := cryptobyte.String(sig)
|
|
||||||
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
|
|
||||||
!input.Empty() ||
|
|
||||||
!inner.ReadASN1Integer(r) ||
|
|
||||||
!inner.ReadASN1Integer(s) ||
|
|
||||||
!inner.Empty() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return ecdsa.Verify(pub, hash, r, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
algorithm apiv1.SignatureAlgorithm
|
|
||||||
signerOpts crypto.SignerOpts
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
// SoftHSM2
|
|
||||||
{"RSA", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7371;object=rsa-key",
|
|
||||||
}}, apiv1.SHA256WithRSA, crypto.SHA256, false},
|
|
||||||
{"RSA PSS", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7372;object=rsa-pss-key",
|
|
||||||
}}, apiv1.SHA256WithRSAPSS, &rsa.PSSOptions{
|
|
||||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
}, false},
|
|
||||||
{"ECDSA P256", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key",
|
|
||||||
}}, apiv1.ECDSAWithSHA256, crypto.SHA256, false},
|
|
||||||
{"ECDSA P384", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7374;object=ecdsa-p384-key",
|
|
||||||
}}, apiv1.ECDSAWithSHA384, crypto.SHA384, false},
|
|
||||||
{"ECDSA P521", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:id=7375;object=ecdsa-p521-key",
|
|
||||||
}}, apiv1.ECDSAWithSHA512, crypto.SHA512, false},
|
|
||||||
{"fail SigningKey", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "",
|
|
||||||
}}, 0, nil, true},
|
|
||||||
{"fail uri", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "https:id=7375;object=ecdsa-p521-key",
|
|
||||||
}}, 0, nil, true},
|
|
||||||
{"fail FindKeyPair", args{&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: "pkcs11:foo=bar",
|
|
||||||
}}, 0, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if got != nil {
|
|
||||||
hash := tt.signerOpts.HashFunc()
|
|
||||||
h := hash.New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
sig, err := got.Sign(rand.Reader, digest, tt.signerOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cyrpto.Signer.Sign() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch tt.algorithm {
|
|
||||||
case apiv1.SHA256WithRSA, apiv1.SHA384WithRSA, apiv1.SHA512WithRSA:
|
|
||||||
pub := got.Public().(*rsa.PublicKey)
|
|
||||||
if err := rsa.VerifyPKCS1v15(pub, hash, digest, sig); err != nil {
|
|
||||||
t.Errorf("rsa.VerifyPKCS1v15() error = %v", err)
|
|
||||||
}
|
|
||||||
case apiv1.UnspecifiedSignAlgorithm, apiv1.SHA256WithRSAPSS, apiv1.SHA384WithRSAPSS, apiv1.SHA512WithRSAPSS:
|
|
||||||
pub := got.Public().(*rsa.PublicKey)
|
|
||||||
if err := rsa.VerifyPSS(pub, hash, digest, sig, tt.signerOpts.(*rsa.PSSOptions)); err != nil {
|
|
||||||
t.Errorf("rsa.VerifyPSS() error = %v", err)
|
|
||||||
}
|
|
||||||
case apiv1.ECDSAWithSHA256, apiv1.ECDSAWithSHA384, apiv1.ECDSAWithSHA512:
|
|
||||||
pub := got.Public().(*ecdsa.PublicKey)
|
|
||||||
if !verifyASN1(pub, digest, sig) {
|
|
||||||
t.Error("ecdsa.VerifyASN1() failed")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
t.Errorf("signature algorithm %s is not supported", tt.algorithm)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_CreateDecrypter(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateDecrypterRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"RSA", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:id=7371;object=rsa-key",
|
|
||||||
}}, false},
|
|
||||||
{"RSA PSS", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:id=7372;object=rsa-pss-key",
|
|
||||||
}}, false},
|
|
||||||
{"ECDSA P256", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:id=7373;object=ecdsa-p256-key",
|
|
||||||
}}, true},
|
|
||||||
{"ECDSA P384", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:id=7374;object=ecdsa-p384-key",
|
|
||||||
}}, true},
|
|
||||||
{"ECDSA P521", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:id=7375;object=ecdsa-p521-key",
|
|
||||||
}}, true},
|
|
||||||
{"fail DecryptionKey", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "",
|
|
||||||
}}, true},
|
|
||||||
{"fail uri", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "https:id=7375;object=ecdsa-p521-key",
|
|
||||||
}}, true},
|
|
||||||
{"fail FindKeyPair", args{&apiv1.CreateDecrypterRequest{
|
|
||||||
DecryptionKey: "pkcs11:foo=bar",
|
|
||||||
}}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.CreateDecrypter(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.CreateDecrypter() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if got != nil {
|
|
||||||
pub := got.Public().(*rsa.PublicKey)
|
|
||||||
// PKCS#1 v1.5
|
|
||||||
enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("rsa.EncryptPKCS1v15() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dec, err := got.Decrypt(rand.Reader, enc, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("PKCS1v15.Decrypt() error = %v", err)
|
|
||||||
} else if !bytes.Equal(dec, data) {
|
|
||||||
t.Errorf("PKCS1v15.Decrypt() failed got = %s, want = %s", dec, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RSA-OAEP
|
|
||||||
enc, err = rsa.EncryptOAEP(crypto.SHA256.New(), rand.Reader, pub, data, []byte("label"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("rsa.EncryptOAEP() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dec, err = got.Decrypt(rand.Reader, enc, &rsa.OAEPOptions{
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
Label: []byte("label"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("RSA-OAEP.Decrypt() error = %v", err)
|
|
||||||
} else if !bytes.Equal(dec, data) {
|
|
||||||
t.Errorf("RSA-OAEP.Decrypt() RSA-OAEP failed got = %s, want = %s", dec, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_LoadCertificate(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
|
|
||||||
getCertFn := func(i, j int) func() *x509.Certificate {
|
|
||||||
return func() *x509.Certificate {
|
|
||||||
return testCerts[i].Certificates[j]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.LoadCertificateRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantFn func() *x509.Certificate
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"load", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:id=7376;object=test-root",
|
|
||||||
}}, getCertFn(0, 0), false},
|
|
||||||
{"load by id", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:id=7376",
|
|
||||||
}}, getCertFn(0, 0), false},
|
|
||||||
{"load by label", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:object=test-root",
|
|
||||||
}}, getCertFn(0, 0), false},
|
|
||||||
{"load by serial", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:serial=64",
|
|
||||||
}}, getCertFn(0, 0), false},
|
|
||||||
{"fail missing", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:id=9999;object=test-root",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail name", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail scheme", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "foo:id=7376;object=test-root",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail serial", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:serial=foo",
|
|
||||||
}}, nil, true},
|
|
||||||
{"fail FindCertificate", args{&apiv1.LoadCertificateRequest{
|
|
||||||
Name: "pkcs11:foo=bar",
|
|
||||||
}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.LoadCertificate(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.LoadCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var want *x509.Certificate
|
|
||||||
if tt.wantFn != nil {
|
|
||||||
want = tt.wantFn()
|
|
||||||
got.Raw, got.RawIssuer, got.RawSubject, got.RawTBSCertificate, got.RawSubjectPublicKeyInfo = nil, nil, nil, nil, nil
|
|
||||||
want.Raw, want.RawIssuer, want.RawSubject, want.RawTBSCertificate, want.RawSubjectPublicKeyInfo = nil, nil, nil, nil, nil
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("PKCS11.LoadCertificate() = %v, want %v", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_StoreCertificate(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ed25519.GenerateKey() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := generateCertificate(pub, priv)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("x509.CreateCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure to delete the created certificate
|
|
||||||
t.Cleanup(func() {
|
|
||||||
k.DeleteCertificate(testObject)
|
|
||||||
k.DeleteCertificate(testObjectAlt)
|
|
||||||
})
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.StoreCertificateRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: testObject,
|
|
||||||
Certificate: cert,
|
|
||||||
}}, false},
|
|
||||||
{"ok extractable", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: testObjectAlt,
|
|
||||||
Certificate: cert,
|
|
||||||
Extractable: true,
|
|
||||||
}}, false},
|
|
||||||
{"fail already exists", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: testObject,
|
|
||||||
Certificate: cert,
|
|
||||||
}}, true},
|
|
||||||
{"fail name", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: "",
|
|
||||||
Certificate: cert,
|
|
||||||
}}, true},
|
|
||||||
{"fail certificate", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: testObject,
|
|
||||||
Certificate: nil,
|
|
||||||
}}, true},
|
|
||||||
{"fail uri", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: "http:id=7770;object=create-cert",
|
|
||||||
Certificate: cert,
|
|
||||||
}}, true},
|
|
||||||
{"fail missing id", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: "pkcs11:object=create-cert",
|
|
||||||
Certificate: cert,
|
|
||||||
}}, true},
|
|
||||||
{"fail missing object", args{&apiv1.StoreCertificateRequest{
|
|
||||||
Name: "pkcs11:id=7770;object=",
|
|
||||||
Certificate: cert,
|
|
||||||
}}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.args.req.Extractable {
|
|
||||||
if testModule == "SoftHSM2" {
|
|
||||||
t.Skip("Extractable certificates are not supported on SoftHSM2")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := k.StoreCertificate(tt.args.req); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.StoreCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if !tt.wantErr {
|
|
||||||
got, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{
|
|
||||||
Name: tt.args.req.Name,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("PKCS11.LoadCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, cert) {
|
|
||||||
t.Errorf("PKCS11.LoadCertificate() = %v, want %v", got, cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_DeleteKey(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
uri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"delete", args{testObject}, false},
|
|
||||||
{"delete by id", args{testObjectByID}, false},
|
|
||||||
{"delete by label", args{testObjectByLabel}, false},
|
|
||||||
{"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false},
|
|
||||||
{"fail name", args{""}, true},
|
|
||||||
{"fail FindKeyPair", args{"pkcs11:foo=bar"}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if _, err := k.CreateKey(&apiv1.CreateKeyRequest{
|
|
||||||
Name: testObject,
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("PKCS1.CreateKey() error = %v", err)
|
|
||||||
}
|
|
||||||
if err := k.DeleteKey(tt.args.uri); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.DeleteKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if _, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
|
|
||||||
Name: tt.args.uri,
|
|
||||||
}); err == nil {
|
|
||||||
t.Error("PKCS11.GetPublicKey() public key found and not expected")
|
|
||||||
}
|
|
||||||
// Make sure to delete the created one.
|
|
||||||
if err := k.DeleteKey(testObject); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteKey() error = %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_DeleteCertificate(t *testing.T) {
|
|
||||||
k := setupPKCS11(t)
|
|
||||||
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ed25519.GenerateKey() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := generateCertificate(pub, priv)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("x509.CreateCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
uri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"delete", args{testObject}, false},
|
|
||||||
{"delete by id", args{testObjectByID}, false},
|
|
||||||
{"delete by label", args{testObjectByLabel}, false},
|
|
||||||
{"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false},
|
|
||||||
{"fail name", args{""}, true},
|
|
||||||
{"fail DeleteCertificate", args{"pkcs11:foo=bar"}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
|
|
||||||
Name: testObject,
|
|
||||||
Certificate: cert,
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("PKCS11.StoreCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
if err := k.DeleteCertificate(tt.args.uri); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.DeleteCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if _, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{
|
|
||||||
Name: tt.args.uri,
|
|
||||||
}); err == nil {
|
|
||||||
t.Error("PKCS11.LoadCertificate() certificate found and not expected")
|
|
||||||
}
|
|
||||||
// Make sure to delete the created one.
|
|
||||||
if err := k.DeleteCertificate(testObject); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPKCS11_Close(t *testing.T) {
|
|
||||||
k := mustPKCS11(t)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", false},
|
|
||||||
{"second", false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PKCS11.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,145 +0,0 @@
|
||||||
//go:build cgo
|
|
||||||
// +build cgo
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"math/big"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
testModule = ""
|
|
||||||
testObject = "pkcs11:id=7370;object=test-name"
|
|
||||||
testObjectAlt = "pkcs11:id=7377;object=alt-test-name"
|
|
||||||
testObjectByID = "pkcs11:id=7370"
|
|
||||||
testObjectByLabel = "pkcs11:object=test-name"
|
|
||||||
testKeys = []struct {
|
|
||||||
Name string
|
|
||||||
SignatureAlgorithm apiv1.SignatureAlgorithm
|
|
||||||
Bits int
|
|
||||||
}{
|
|
||||||
{"pkcs11:id=7371;object=rsa-key", apiv1.SHA256WithRSA, 2048},
|
|
||||||
{"pkcs11:id=7372;object=rsa-pss-key", apiv1.SHA256WithRSAPSS, DefaultRSASize},
|
|
||||||
{"pkcs11:id=7373;object=ecdsa-p256-key", apiv1.ECDSAWithSHA256, 0},
|
|
||||||
{"pkcs11:id=7374;object=ecdsa-p384-key", apiv1.ECDSAWithSHA384, 0},
|
|
||||||
{"pkcs11:id=7375;object=ecdsa-p521-key", apiv1.ECDSAWithSHA512, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCerts = []struct {
|
|
||||||
Name string
|
|
||||||
Key string
|
|
||||||
Certificates []*x509.Certificate
|
|
||||||
}{
|
|
||||||
{"pkcs11:id=7376;object=test-root", "pkcs11:id=7373;object=ecdsa-p256-key", nil},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type TBTesting interface {
|
|
||||||
Helper()
|
|
||||||
Cleanup(f func())
|
|
||||||
Log(args ...interface{})
|
|
||||||
Errorf(format string, args ...interface{})
|
|
||||||
Fatalf(format string, args ...interface{})
|
|
||||||
Skipf(format string, args ...interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateCertificate(pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) {
|
|
||||||
now := time.Now()
|
|
||||||
template := &x509.Certificate{
|
|
||||||
Subject: pkix.Name{CommonName: "Test Root Certificate"},
|
|
||||||
Issuer: pkix.Name{CommonName: "Test Root Certificate"},
|
|
||||||
IsCA: true,
|
|
||||||
MaxPathLen: 1,
|
|
||||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
|
||||||
NotBefore: now,
|
|
||||||
NotAfter: now.Add(time.Hour),
|
|
||||||
SerialNumber: big.NewInt(100),
|
|
||||||
}
|
|
||||||
|
|
||||||
b, err := x509.CreateCertificate(rand.Reader, template, template, pub, signer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return x509.ParseCertificate(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup(t TBTesting, k *PKCS11) {
|
|
||||||
t.Log("Running using", testModule)
|
|
||||||
for _, tk := range testKeys {
|
|
||||||
_, err := k.CreateKey(&apiv1.CreateKeyRequest{
|
|
||||||
Name: tk.Name,
|
|
||||||
SignatureAlgorithm: tk.SignatureAlgorithm,
|
|
||||||
Bits: tk.Bits,
|
|
||||||
})
|
|
||||||
if err != nil && !errors.Is(errors.Cause(err), apiv1.ErrAlreadyExists{
|
|
||||||
Message: tk.Name + " already exists",
|
|
||||||
}) {
|
|
||||||
t.Errorf("PKCS11.GetPublicKey() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, c := range testCerts {
|
|
||||||
signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: c.Key,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("PKCS11.CreateSigner() error = %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cert, err := generateCertificate(signer.Public(), signer)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("x509.CreateCertificate() error = %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
|
|
||||||
Name: c.Name,
|
|
||||||
Certificate: cert,
|
|
||||||
}); err != nil && !errors.Is(errors.Cause(err), apiv1.ErrAlreadyExists{
|
|
||||||
Message: c.Name + " already exists",
|
|
||||||
}) {
|
|
||||||
t.Errorf("PKCS1.StoreCertificate() error = %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
testCerts[i].Certificates = append(testCerts[i].Certificates, cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func teardown(t TBTesting, k *PKCS11) {
|
|
||||||
testObjects := []string{testObject, testObjectByID, testObjectByLabel}
|
|
||||||
for _, name := range testObjects {
|
|
||||||
if err := k.DeleteKey(name); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteKey() error = %v", err)
|
|
||||||
}
|
|
||||||
if err := k.DeleteCertificate(name); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, tk := range testKeys {
|
|
||||||
if err := k.DeleteKey(tk.Name); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteKey() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, tc := range testCerts {
|
|
||||||
if err := k.DeleteCertificate(tc.Name); err != nil {
|
|
||||||
t.Errorf("PKCS11.DeleteCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupPKCS11(t TBTesting) *PKCS11 {
|
|
||||||
t.Helper()
|
|
||||||
k := mustPKCS11(t)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
k.Close()
|
|
||||||
})
|
|
||||||
return k
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
//go:build cgo && softhsm2
|
|
||||||
// +build cgo,softhsm2
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
)
|
|
||||||
|
|
||||||
var softHSM2Once sync.Once
|
|
||||||
|
|
||||||
// mustPKCS11 configures a *PKCS11 KMS to be used with SoftHSM2. To initialize
|
|
||||||
// these tests, we should run:
|
|
||||||
//
|
|
||||||
// softhsm2-util --init-token --free \
|
|
||||||
// --token pkcs11-test --label pkcs11-test \
|
|
||||||
// --so-pin password --pin password
|
|
||||||
//
|
|
||||||
// To delete we should run:
|
|
||||||
//
|
|
||||||
// softhsm2-util --delete-token --token pkcs11-test
|
|
||||||
func mustPKCS11(t TBTesting) *PKCS11 {
|
|
||||||
t.Helper()
|
|
||||||
testModule = "SoftHSM2"
|
|
||||||
if runtime.GOARCH != "amd64" {
|
|
||||||
t.Fatalf("softHSM2 test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
|
|
||||||
}
|
|
||||||
|
|
||||||
var path string
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
path = "/usr/local/lib/softhsm/libsofthsm2.so"
|
|
||||||
case "linux":
|
|
||||||
path = "/usr/lib/softhsm/libsofthsm2.so"
|
|
||||||
default:
|
|
||||||
t.Skipf("softHSM2 test skipped on %s", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
p11, err := crypto11.Configure(&crypto11.Config{
|
|
||||||
Path: path,
|
|
||||||
TokenLabel: "pkcs11-test",
|
|
||||||
Pin: "password",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to configure softHSM2 on %s: %v", runtime.GOOS, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
k := &PKCS11{
|
|
||||||
p11: p11,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup
|
|
||||||
softHSM2Once.Do(func() {
|
|
||||||
teardown(t, k)
|
|
||||||
setup(t, k)
|
|
||||||
})
|
|
||||||
|
|
||||||
return k
|
|
||||||
}
|
|
|
@ -1,56 +0,0 @@
|
||||||
//go:build cgo && yubihsm2
|
|
||||||
// +build cgo,yubihsm2
|
|
||||||
|
|
||||||
package pkcs11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ThalesIgnite/crypto11"
|
|
||||||
)
|
|
||||||
|
|
||||||
var yubiHSM2Once sync.Once
|
|
||||||
|
|
||||||
// mustPKCS11 configures a *PKCS11 KMS to be used with YubiHSM2. To initialize
|
|
||||||
// these tests, we should run:
|
|
||||||
//
|
|
||||||
// yubihsm-connector -d
|
|
||||||
func mustPKCS11(t TBTesting) *PKCS11 {
|
|
||||||
t.Helper()
|
|
||||||
testModule = "YubiHSM2"
|
|
||||||
if runtime.GOARCH != "amd64" {
|
|
||||||
t.Skipf("yubiHSM2 test skipped on %s:%s", runtime.GOOS, runtime.GOARCH)
|
|
||||||
}
|
|
||||||
|
|
||||||
var path string
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
path = "/usr/local/lib/pkcs11/yubihsm_pkcs11.dylib"
|
|
||||||
case "linux":
|
|
||||||
path = "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so"
|
|
||||||
default:
|
|
||||||
t.Skipf("yubiHSM2 test skipped on %s", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
p11, err := crypto11.Configure(&crypto11.Config{
|
|
||||||
Path: path,
|
|
||||||
TokenLabel: "YubiHSM",
|
|
||||||
Pin: "0001password",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to configure YubiHSM2 on %s: %v", runtime.GOOS, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
k := &PKCS11{
|
|
||||||
p11: p11,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup
|
|
||||||
yubiHSM2Once.Do(func() {
|
|
||||||
teardown(t, k)
|
|
||||||
setup(t, k)
|
|
||||||
})
|
|
||||||
|
|
||||||
return k
|
|
||||||
}
|
|
|
@ -1,183 +0,0 @@
|
||||||
package softkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"go.step.sm/cli-utils/ui"
|
|
||||||
"go.step.sm/crypto/keyutil"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
type algorithmAttributes struct {
|
|
||||||
Type string
|
|
||||||
Curve string
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultRSAKeySize is the default size for RSA keys.
|
|
||||||
const DefaultRSAKeySize = 3072
|
|
||||||
|
|
||||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{
|
|
||||||
apiv1.UnspecifiedSignAlgorithm: {"EC", "P-256"},
|
|
||||||
apiv1.SHA256WithRSA: {"RSA", ""},
|
|
||||||
apiv1.SHA384WithRSA: {"RSA", ""},
|
|
||||||
apiv1.SHA512WithRSA: {"RSA", ""},
|
|
||||||
apiv1.SHA256WithRSAPSS: {"RSA", ""},
|
|
||||||
apiv1.SHA384WithRSAPSS: {"RSA", ""},
|
|
||||||
apiv1.SHA512WithRSAPSS: {"RSA", ""},
|
|
||||||
apiv1.ECDSAWithSHA256: {"EC", "P-256"},
|
|
||||||
apiv1.ECDSAWithSHA384: {"EC", "P-384"},
|
|
||||||
apiv1.ECDSAWithSHA512: {"EC", "P-521"},
|
|
||||||
apiv1.PureEd25519: {"OKP", "Ed25519"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKey is used for testing purposes.
|
|
||||||
var generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
|
|
||||||
if kty == "RSA" && size == 0 {
|
|
||||||
size = DefaultRSAKeySize
|
|
||||||
}
|
|
||||||
return keyutil.GenerateKeyPair(kty, crv, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoftKMS is a key manager that uses keys stored in disk.
|
|
||||||
type SoftKMS struct{}
|
|
||||||
|
|
||||||
// New returns a new SoftKMS.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*SoftKMS, error) {
|
|
||||||
return &SoftKMS{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
pemutil.PromptPassword = func(msg string) ([]byte, error) {
|
|
||||||
return ui.PromptPassword(msg)
|
|
||||||
}
|
|
||||||
apiv1.Register(apiv1.SoftKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close is a noop that just returns nil.
|
|
||||||
func (k *SoftKMS) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner returns a new signer configured with the given signing key.
|
|
||||||
func (k *SoftKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
var opts []pemutil.Options
|
|
||||||
if req.Password != nil {
|
|
||||||
opts = append(opts, pemutil.WithPassword(req.Password))
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case req.Signer != nil:
|
|
||||||
return req.Signer, nil
|
|
||||||
case len(req.SigningKeyPEM) != 0:
|
|
||||||
v, err := pemutil.ParseKey(req.SigningKeyPEM, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sig, ok := v.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("signingKeyPEM is not a crypto.Signer")
|
|
||||||
}
|
|
||||||
return sig, nil
|
|
||||||
case req.SigningKey != "":
|
|
||||||
v, err := pemutil.Read(req.SigningKey, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sig, ok := v.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("signingKey is not a crypto.Signer")
|
|
||||||
}
|
|
||||||
return sig, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("failed to load softKMS: please define signingKeyPEM or signingKey")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey generates a new key using Golang crypto and returns both public and
|
|
||||||
// private key.
|
|
||||||
func (k *SoftKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("softKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, priv, err := generateKey(v.Type, v.Curve, req.Bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
signer, ok := priv.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("softKMS createKey result is not a crypto.Signer: type %T", priv)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: req.Name,
|
|
||||||
PublicKey: pub,
|
|
||||||
PrivateKey: priv,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
Signer: signer,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey returns the public key from the file passed in the request name.
|
|
||||||
func (k *SoftKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
v, err := pemutil.Read(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch vv := v.(type) {
|
|
||||||
case *x509.Certificate:
|
|
||||||
return vv.PublicKey, nil
|
|
||||||
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
|
|
||||||
return vv, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unsupported public key type %T", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDecrypter creates a new crypto.Decrypter backed by disk/software
|
|
||||||
func (k *SoftKMS) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
|
|
||||||
|
|
||||||
var opts []pemutil.Options
|
|
||||||
if req.Password != nil {
|
|
||||||
opts = append(opts, pemutil.WithPassword(req.Password))
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case req.Decrypter != nil:
|
|
||||||
return req.Decrypter, nil
|
|
||||||
case len(req.DecryptionKeyPEM) != 0:
|
|
||||||
v, err := pemutil.ParseKey(req.DecryptionKeyPEM, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, ok := v.(crypto.Decrypter)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("decryptorKeyPEM is not a crypto.Decrypter")
|
|
||||||
}
|
|
||||||
return decrypter, nil
|
|
||||||
case req.DecryptionKey != "":
|
|
||||||
v, err := pemutil.Read(req.DecryptionKey, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, ok := v.(crypto.Decrypter)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("decryptionKey is not a crypto.Decrypter")
|
|
||||||
}
|
|
||||||
return decrypter, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("failed to load softKMS: please define decryptionKeyPEM or decryptionKey")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,381 +0,0 @@
|
||||||
package softkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *SoftKMS
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{context.Background(), apiv1.Options{}}, &SoftKMS{}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := New(tt.args.ctx, tt.args.opts)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoftKMS_Close(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SoftKMS{}
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SoftKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoftKMS_CreateSigner(t *testing.T) {
|
|
||||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlock, err := pemutil.Serialize(pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and decode file using standard packages
|
|
||||||
b, err := os.ReadFile("testdata/priv.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(b)
|
|
||||||
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk2, err := x509.ParseECPrivateKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a public PEM
|
|
||||||
b, err = x509.MarshalPKIXPublicKey(pk.Public())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "PUBLIC KEY",
|
|
||||||
Bytes: b,
|
|
||||||
})
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
|
|
||||||
{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
|
|
||||||
{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
|
|
||||||
{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
|
|
||||||
{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
||||||
{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
|
|
||||||
{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
|
|
||||||
{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
|
|
||||||
{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
|
|
||||||
{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SoftKMS{}
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SoftKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SoftKMS.CreateSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func restoreGenerateKey() func() {
|
|
||||||
oldGenerateKey := generateKey
|
|
||||||
return func() {
|
|
||||||
generateKey = oldGenerateKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoftKMS_CreateKey(t *testing.T) {
|
|
||||||
fn := restoreGenerateKey()
|
|
||||||
defer fn()
|
|
||||||
|
|
||||||
p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
type params struct {
|
|
||||||
kty string
|
|
||||||
crv string
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
generateKey func() (interface{}, interface{}, error)
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantParams params
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
|
|
||||||
{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
|
|
||||||
{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
|
|
||||||
{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
|
|
||||||
{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
|
|
||||||
return edpub, edpriv, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
|
|
||||||
{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
|
|
||||||
{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, nil, params{}, true},
|
|
||||||
{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return nil, nil, fmt.Errorf("an error")
|
|
||||||
}, nil, params{"EC", "P-256", 0}, true},
|
|
||||||
{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return 1, 2, nil
|
|
||||||
}, nil, params{"EC", "P-256", 0}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SoftKMS{}
|
|
||||||
generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
|
|
||||||
if tt.wantParams.kty != kty {
|
|
||||||
t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
|
|
||||||
}
|
|
||||||
if tt.wantParams.crv != crv {
|
|
||||||
t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
|
|
||||||
}
|
|
||||||
if tt.wantParams.size != size {
|
|
||||||
t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
|
|
||||||
}
|
|
||||||
return tt.generateKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SoftKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SoftKMS.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoftKMS_GetPublicKey(t *testing.T) {
|
|
||||||
b, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(b)
|
|
||||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
|
|
||||||
{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
|
|
||||||
{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
|
|
||||||
{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SoftKMS{}
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SoftKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SoftKMS.GetPublicKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_generateKey(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
kty string
|
|
||||||
crv string
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantType interface{}
|
|
||||||
wantType1 interface{}
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"rsa2048", args{"RSA", "", 0}, &rsa.PublicKey{}, &rsa.PrivateKey{}, false},
|
|
||||||
{"rsa2048", args{"RSA", "", 2048}, &rsa.PublicKey{}, &rsa.PrivateKey{}, false},
|
|
||||||
{"p256", args{"EC", "P-256", 0}, &ecdsa.PublicKey{}, &ecdsa.PrivateKey{}, false},
|
|
||||||
{"ed25519", args{"OKP", "Ed25519", 0}, ed25519.PublicKey{}, ed25519.PrivateKey{}, false},
|
|
||||||
{"fail kty", args{"FOO", "", 0}, nil, nil, true},
|
|
||||||
{"fail crv", args{"EC", "P-123", 0}, nil, nil, true},
|
|
||||||
{"fail size", args{"RSA", "", 1}, nil, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, got1, err := generateKey(tt.args.kty, tt.args.crv, tt.args.size)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("generateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(got) != reflect.TypeOf(tt.wantType) {
|
|
||||||
t.Errorf("generateKey() got = %T, want %T", got, tt.wantType)
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(got1) != reflect.TypeOf(tt.wantType1) {
|
|
||||||
t.Errorf("generateKey() got1 = %T, want %T", got1, tt.wantType1)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoftKMS_CreateDecrypter(t *testing.T) {
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlock, err := pemutil.Serialize(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlockPassword, err := pemutil.Serialize(privateKey, pemutil.WithPassword([]byte("pass")))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ecdsaPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ecdsaPemBlock, err := pemutil.Serialize(ecdsaPK)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
b, err := os.ReadFile("testdata/rsa.priv.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(b)
|
|
||||||
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
keyFromFile, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateDecrypterRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.Decrypter
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"decrypter", args{&apiv1.CreateDecrypterRequest{Decrypter: privateKey}}, privateKey, false},
|
|
||||||
{"file", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/rsa.priv.pem", Password: []byte("pass")}}, keyFromFile, false},
|
|
||||||
{"pem", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlock)}}, privateKey, false},
|
|
||||||
{"pem password", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, privateKey, false},
|
|
||||||
{"fail none", args{&apiv1.CreateDecrypterRequest{}}, nil, true},
|
|
||||||
{"fail missing", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/missing"}}, nil, true},
|
|
||||||
{"fail bad pem", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: []byte("bad pem")}}, nil, true},
|
|
||||||
{"fail bad password", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("bad-pass")}}, nil, true},
|
|
||||||
{"fail not a decrypter (ecdsa key)", args{&apiv1.CreateDecrypterRequest{DecryptionKeyPEM: pem.EncodeToMemory(ecdsaPemBlock)}}, nil, true},
|
|
||||||
{"fail not a decrypter from file", args{&apiv1.CreateDecrypterRequest{DecryptionKey: "testdata/rsa.pub.pem"}}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SoftKMS{}
|
|
||||||
got, err := k.CreateDecrypter(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SoftKMS.CreateDecrypter(), error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SoftKMS.CreateDecrypter() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
11
kms/softkms/testdata/cert.crt
vendored
11
kms/softkms/testdata/cert.crt
vendored
|
@ -1,11 +0,0 @@
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIBpzCCAU2gAwIBAgIQWaY8KIDAfak8aYljelf8eTAKBggqhkjOPQQDAjAdMRsw
|
|
||||||
GQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wHhcNMjAwMTE2MDAwNDU4WhcNMjAw
|
|
||||||
MTE3MDAwNDU4WjAdMRswGQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wWTATBgcq
|
|
||||||
hkjOPQIBBggqhkjOPQMBBwNCAATlU8P9blFefSWuzYx2g215NJn6yHW95PXeFqQ9
|
|
||||||
kX1jNo1VmC6Oord3We37iM8QJT4QP9ZDUaAVmJUZSjd+W8H/o28wbTAOBgNVHQ8B
|
|
||||||
Af8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
|
|
||||||
BBTn0wonKkm2lLRNYZrKhUukiynvqzAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3Rl
|
|
||||||
cC5jb20wCgYIKoZIzj0EAwIDSAAwRQIhAJ5XqryBIY1X4fl/9l0isV69eQfA0Qo5
|
|
||||||
1mjervUcEnOWAiBsmN4frz5YVw7i4UXChVBeZLZfJOKvn5eyh2gEzoq1+w==
|
|
||||||
-----END CERTIFICATE-----
|
|
5
kms/softkms/testdata/cert.key
vendored
5
kms/softkms/testdata/cert.key
vendored
|
@ -1,5 +0,0 @@
|
||||||
-----BEGIN EC PRIVATE KEY-----
|
|
||||||
MHcCAQEEICB6lIrMa9fVQJtdAYS4qmdYQ1BHJsEQDx8zxL38gA8toAoGCCqGSM49
|
|
||||||
AwEHoUQDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1veT13hakPZF9YzaNVZgujqK3
|
|
||||||
d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
8
kms/softkms/testdata/priv.pem
vendored
8
kms/softkms/testdata/priv.pem
vendored
|
@ -1,8 +0,0 @@
|
||||||
-----BEGIN EC PRIVATE KEY-----
|
|
||||||
Proc-Type: 4,ENCRYPTED
|
|
||||||
DEK-Info: AES-256-CBC,1fcec5dfbf3327f61bfe5ab6ae8a0626
|
|
||||||
|
|
||||||
V39b/pNHMbP80TXSHLsUY6UOTCzf3KwIxvj1e7S9brNMJJc9b3UiloMBJIYBkl00
|
|
||||||
NKI8JU4jSlcerR58DqsTHIELiX6a+RJLe3/iR2/5Gru+CmmWJ68jQu872WCgh6Ms
|
|
||||||
o8TzhyGx74ETmdKn5CdtylsnKMa9heW3tBLFAbNCgKc=
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
4
kms/softkms/testdata/pub.pem
vendored
4
kms/softkms/testdata/pub.pem
vendored
|
@ -1,4 +0,0 @@
|
||||||
-----BEGIN PUBLIC KEY-----
|
|
||||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
|
|
||||||
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
|
|
||||||
-----END PUBLIC KEY-----
|
|
30
kms/softkms/testdata/rsa.priv.pem
vendored
30
kms/softkms/testdata/rsa.priv.pem
vendored
|
@ -1,30 +0,0 @@
|
||||||
-----BEGIN RSA PRIVATE KEY-----
|
|
||||||
Proc-Type: 4,ENCRYPTED
|
|
||||||
DEK-Info: AES-256-CBC,dff7bfd0e0163a4cd7ade8f68b966699
|
|
||||||
|
|
||||||
jtmOhr2zo244Oq2fVsShZAUoQZ1gi6Iwc4i0sReU66XP9CFkdvJasAfkjQGrbCEy
|
|
||||||
m2+r7W6aH+L3j/4sXcJe8h4UVnnC4DHCozmtqqFCq7cFS4TiVpco26wEVH5WLm7Y
|
|
||||||
3Ew/pL0k24E+Ycf+yV5c1tQXRlmsKubjwzrZtGZP2yn3Dxsu97mzOXAfx7r+DIKI
|
|
||||||
5a4S3m1/yXw76tt6Iho9h4huA25UUDHKUQvOGd5gmOKqJRV9djoyu85ODbmz5nt0
|
|
||||||
pB2EzdHOrefgd0rcQQPI1uFBWqASJxTn+uS7ZBP4rlCcs932lI1mPerMh1ujo51F
|
|
||||||
3aibrwhKE6kaJyOOnUbvyBnaiTb5i4WwTqx/jfsOsggXQb3UlxgDph48VXw8O2jF
|
|
||||||
CQmle+TR8yr1A14/Dno5Dd4cqPv6AmWWU2zolvLxKQixFcvjsyQYCDajWWRPkOgj
|
|
||||||
RTKXDqL1mpjrlDqcSXzemCWk6FzqdUQhimhFgARDRfRwwDeWQN5ua4a3gnem/cpA
|
|
||||||
ZS8J45H0ZC/CxGPfp+qx75n5a875+n4VMmCZerXPzEIj1CzS7D6BVAXTHJaNIB6S
|
|
||||||
0WNfQnftp09O2l6iXBE+MHt5bVxqt46+vgcceSu7Gsb3ZfD79vnQ7tR+wb+xmHKk
|
|
||||||
8rVcMrB+kDRXVguH/a3zUGYAEnb6hPkIJywJVD4G65oM+D9D67Mdka8wIMK48doV
|
|
||||||
my8a0MfT/9AidR6XJVxIkHlPsPzlxirm/NKF7oSlzurcvYcPAYnHYLW2uB8dyidq
|
|
||||||
1zB+3rxbSYCVqrhqzN4prydGvkIE3/+AJyIGn7uGSTSSyF6BC9APXQaHplRGKwLz
|
|
||||||
efOIMoEwXJ1DIcKmk9GB65xxrZxMu3Cclcbc4PgY4370G0PfCHuUQNQL2RUWCQn0
|
|
||||||
aax+qDiFg1LsLRaI75OaLJ+uKs6rRfytQMmFGqK/b6iVbktiYWMtrDJDo4OUTtZ6
|
|
||||||
LBBySH7sAFgI3IIxct2Fwg8X1J4kfHr9jWTLjMEIE2o8cyqvSQ8rdwA25MxRcn75
|
|
||||||
DGqSlGE6Sx0XhWCVUiZidVRSYGKmOmH9yw8cjKm17qL23t8Gwns4Xunl7V6YlTCG
|
|
||||||
BPw5f1jWCQ94TwvUSuHMPYoXlYwRoe+jfDAzp2AQwXqvWX5Qno5PKz9gQ5iYacZ/
|
|
||||||
k82fyPbk2XLDkPnaNJKnyiIc252O0WffUlX6Rlv3aF8ZgVvWfZbuHEK6g1W+IKSA
|
|
||||||
pXAQ+iZBl+fjs/wT0yZSNTB0P1InD9Ve536L94gxXoeMr6F0Eouk3J2R9qdFp0Av
|
|
||||||
31xylRKSmzUf87/sRxjy3FzSTjIal77y1euJoAEU/nShmNrAZ6B8wnlvHfVwbgmt
|
|
||||||
xWqxYIi/j/C8Led9uhEhX2WjPsO7ckGA41Tw6hZk/5hr4jmPoZQKHf9OauJFujMh
|
|
||||||
ybPRQ6SGZJaYQAgpEGHSHFm8lwf5/DcezdSMdzqAKBWJBv6MediMuS60wcJ0Tebk
|
|
||||||
rdLkNE4bsxfc889BkXBrSqfd+Auu5RcF/kF44gLL7oj4ojQyV44vLZbC4+liGThT
|
|
||||||
bhayYGV64hsY+zL03u5wVfF1Y+33/uc8o/0JjbfuW5AIdikVES/jnKKFXSTMNL69
|
|
||||||
-----END RSA PRIVATE KEY-----
|
|
9
kms/softkms/testdata/rsa.pub.pem
vendored
9
kms/softkms/testdata/rsa.pub.pem
vendored
|
@ -1,9 +0,0 @@
|
||||||
-----BEGIN PUBLIC KEY-----
|
|
||||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAn2Oh7/uWB5RH40la1a43
|
|
||||||
IRaLZ8EnJVw5DCKE3BUre8xflVY2wTIS7XHcY0fEGprtq7hzFKors9AIGGn2yGrf
|
|
||||||
bZX2I+1g+RtQ6cLL6koeLuhRDqCuae0lZPulWc5ixBmM9mpl4ARRcpQFldxFRhis
|
|
||||||
xUaHMx8VqdZjFSDc5CJHYYK1n2G5DyuzJCk6yOfyMpwxizZJF4IUyqV7zKmZv1z9
|
|
||||||
/Xd8X0ag7jRdaTBpupJ1WLaq7LlvyB4nr47JXXkLFbRIL1F/gTcPtg0tdEZiKnxs
|
|
||||||
VLKwOs3VjhEorUwhmVxr4NnNX/0tuOY1FJ0mx5jKLAevqLVwK2JIg/f3h7JcNxDy
|
|
||||||
tQIDAQAB
|
|
||||||
-----END PUBLIC KEY-----
|
|
|
@ -1,206 +0,0 @@
|
||||||
package sshagentkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/agent"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SSHAgentKMS is a key manager that uses keys provided by ssh-agent
|
|
||||||
type SSHAgentKMS struct {
|
|
||||||
agentClient agent.Agent
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new SSHAgentKMS.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*SSHAgentKMS, error) {
|
|
||||||
socket := os.Getenv("SSH_AUTH_SOCK")
|
|
||||||
conn, err := net.Dial("unix", socket)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to open SSH_AUTH_SOCK")
|
|
||||||
}
|
|
||||||
|
|
||||||
agentClient := agent.NewClient(conn)
|
|
||||||
|
|
||||||
return &SSHAgentKMS{
|
|
||||||
agentClient: agentClient,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFromAgent initializes an SSHAgentKMS from a given agent, this method is
|
|
||||||
// used for testing purposes.
|
|
||||||
func NewFromAgent(ctx context.Context, opts apiv1.Options, agentClient agent.Agent) (*SSHAgentKMS, error) {
|
|
||||||
return &SSHAgentKMS{
|
|
||||||
agentClient: agentClient,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.SSHAgentKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the agent. This is a noop for the SSHAgentKMS.
|
|
||||||
func (k *SSHAgentKMS) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WrappedSSHSigner is a utility type to wrap a ssh.Signer as a crypto.Signer
|
|
||||||
type WrappedSSHSigner struct {
|
|
||||||
Sshsigner ssh.Signer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public returns the agent public key. The type of this public key is
|
|
||||||
// *agent.Key.
|
|
||||||
func (s *WrappedSSHSigner) Public() crypto.PublicKey {
|
|
||||||
return s.Sshsigner.PublicKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign signs the given digest using the ssh agent and returns the signature.
|
|
||||||
func (s *WrappedSSHSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
|
||||||
sig, err := s.Sshsigner.Sign(rand, digest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return sig.Blob, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWrappedSignerFromSSHSigner returns a new crypto signer wrapping the given
|
|
||||||
// one.
|
|
||||||
func NewWrappedSignerFromSSHSigner(signer ssh.Signer) crypto.Signer {
|
|
||||||
return &WrappedSSHSigner{signer}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *SSHAgentKMS) findKey(signingKey string) (target int, err error) {
|
|
||||||
if strings.HasPrefix(signingKey, "sshagentkms:") {
|
|
||||||
var key = strings.TrimPrefix(signingKey, "sshagentkms:")
|
|
||||||
|
|
||||||
l, err := k.agentClient.List()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
for i, s := range l {
|
|
||||||
if s.Comment == key {
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1, errors.Errorf("SSHAgentKMS couldn't find %s", signingKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner returns a new signer configured with the given signing key.
|
|
||||||
func (k *SSHAgentKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
if req.Signer != nil {
|
|
||||||
return req.Signer, nil
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(req.SigningKey, "sshagentkms:") {
|
|
||||||
target, err := k.findKey(req.SigningKey)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s, err := k.agentClient.Signers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewWrappedSignerFromSSHSigner(s[target]), nil
|
|
||||||
}
|
|
||||||
// OK: We don't actually care about non-ssh certificates,
|
|
||||||
// but we can't disable it in step-ca so this code is copy-pasted from
|
|
||||||
// softkms just to keep step-ca happy.
|
|
||||||
var opts []pemutil.Options
|
|
||||||
if req.Password != nil {
|
|
||||||
opts = append(opts, pemutil.WithPassword(req.Password))
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case len(req.SigningKeyPEM) != 0:
|
|
||||||
v, err := pemutil.ParseKey(req.SigningKeyPEM, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sig, ok := v.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("signingKeyPEM is not a crypto.Signer")
|
|
||||||
}
|
|
||||||
return sig, nil
|
|
||||||
case req.SigningKey != "":
|
|
||||||
v, err := pemutil.Read(req.SigningKey, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sig, ok := v.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("signingKey is not a crypto.Signer")
|
|
||||||
}
|
|
||||||
return sig, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("failed to load softKMS: please define signingKeyPEM or signingKey")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey generates a new key and returns both public and private key.
|
|
||||||
func (k *SSHAgentKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
return nil, errors.Errorf("SSHAgentKMS doesn't support generating keys")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey returns the public key from the file passed in the request name.
|
|
||||||
func (k *SSHAgentKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
var v crypto.PublicKey
|
|
||||||
if strings.HasPrefix(req.Name, "sshagentkms:") {
|
|
||||||
target, err := k.findKey(req.Name)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := k.agentClient.Signers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPub := s[target].PublicKey()
|
|
||||||
|
|
||||||
sshPubBytes := sshPub.Marshal()
|
|
||||||
|
|
||||||
parsed, err := ssh.ParsePublicKey(sshPubBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedCryptoKey := parsed.(ssh.CryptoPublicKey)
|
|
||||||
|
|
||||||
// Then, we can call CryptoPublicKey() to get the actual crypto.PublicKey
|
|
||||||
v = parsedCryptoKey.CryptoPublicKey()
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
v, err = pemutil.Read(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch vv := v.(type) {
|
|
||||||
case *x509.Certificate:
|
|
||||||
return vv.PublicKey, nil
|
|
||||||
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
|
|
||||||
return vv, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unsupported public key type %T", v)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,609 +0,0 @@
|
||||||
package sshagentkms
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/agent"
|
|
||||||
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Some helpers with inspiration from crypto/ssh/agent/client_test.go
|
|
||||||
|
|
||||||
// startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
|
|
||||||
func startOpenSSHAgent(t *testing.T) (client agent.Agent, socket string, cleanup func()) {
|
|
||||||
/* Always test with OpenSSHAgent
|
|
||||||
if testing.Short() {
|
|
||||||
// ssh-agent is not always available, and the key
|
|
||||||
// types supported vary by platform.
|
|
||||||
t.Skip("skipping test due to -short")
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
bin, err := exec.LookPath("ssh-agent")
|
|
||||||
if err != nil {
|
|
||||||
t.Skip("could not find ssh-agent")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(bin, "-s")
|
|
||||||
cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
|
|
||||||
cmd.Stderr = new(bytes.Buffer)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output looks like:
|
|
||||||
//
|
|
||||||
// SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
|
|
||||||
// SSH_AGENT_PID=15542; export SSH_AGENT_PID;
|
|
||||||
// echo Agent pid 15542;
|
|
||||||
|
|
||||||
fields := bytes.Split(out, []byte(";"))
|
|
||||||
line := bytes.SplitN(fields[0], []byte("="), 2)
|
|
||||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
|
||||||
if string(line[0]) != "SSH_AUTH_SOCK" {
|
|
||||||
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
|
|
||||||
}
|
|
||||||
socket = string(line[1])
|
|
||||||
|
|
||||||
line = bytes.SplitN(fields[2], []byte("="), 2)
|
|
||||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
|
||||||
if string(line[0]) != "SSH_AGENT_PID" {
|
|
||||||
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
|
|
||||||
}
|
|
||||||
pidStr := line[1]
|
|
||||||
pid, err := strconv.Atoi(string(pidStr))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Atoi(%q): %v", pidStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.Dial("unix", string(socket))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("net.Dial: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ac := agent.NewClient(conn)
|
|
||||||
return ac, socket, func() {
|
|
||||||
proc, _ := os.FindProcess(pid)
|
|
||||||
if proc != nil {
|
|
||||||
proc.Kill()
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
os.RemoveAll(filepath.Dir(socket))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startAgent(t *testing.T, sshagent agent.Agent) (client agent.Agent, cleanup func()) {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
go agent.ServeAgent(sshagent, c2)
|
|
||||||
|
|
||||||
return agent.NewClient(c1), func() {
|
|
||||||
c1.Close()
|
|
||||||
c2.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
|
|
||||||
func startKeyringAgent(t *testing.T) (client agent.Agent, cleanup func()) {
|
|
||||||
return startAgent(t, agent.NewKeyring())
|
|
||||||
}
|
|
||||||
|
|
||||||
type startTestAgentFunc func(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent)
|
|
||||||
|
|
||||||
func startTestOpenSSHAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
|
|
||||||
sshagent, _, cleanup := startOpenSSHAgent(t)
|
|
||||||
for _, keyToAdd := range keysToAdd {
|
|
||||||
err := sshagent.Add(keyToAdd)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sshagent.add: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
//testAgentInterface(t, sshagent, key, cert, lifetimeSecs)
|
|
||||||
return sshagent
|
|
||||||
}
|
|
||||||
|
|
||||||
func startTestKeyringAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
|
|
||||||
sshagent, cleanup := startKeyringAgent(t)
|
|
||||||
for _, keyToAdd := range keysToAdd {
|
|
||||||
err := sshagent.Add(keyToAdd)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sshagent.add: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
//testAgentInterface(t, agent, key, cert, lifetimeSecs)
|
|
||||||
return sshagent
|
|
||||||
}
|
|
||||||
|
|
||||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
|
||||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
|
||||||
// a write.)
|
|
||||||
func netPipe() (net.Conn, net.Conn, error) {
|
|
||||||
listener, err := netListener()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
c1, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c2, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
c1.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c1, c2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// netListener creates a localhost network listener.
|
|
||||||
func netListener() (net.Listener, error) {
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
listener, err = net.Listen("tcp", "[::1]:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
comment := "Key from OpenSSHAgent"
|
|
||||||
// Ensure we don't "inherit" any SSH_AUTH_SOCK
|
|
||||||
os.Unsetenv("SSH_AUTH_SOCK")
|
|
||||||
|
|
||||||
sshagent, socket, cleanup := startOpenSSHAgent(t)
|
|
||||||
|
|
||||||
os.Setenv("SSH_AUTH_SOCK", socket)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
os.Unsetenv("SSH_AUTH_SOCK")
|
|
||||||
cleanup()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test that we can't find any signers in the agent before we have loaded them
|
|
||||||
t.Run("No keys with OpenSSHAgent", func(t *testing.T) {
|
|
||||||
kms, err := New(context.Background(), apiv1.Options{})
|
|
||||||
if kms == nil || err != nil {
|
|
||||||
t.Errorf("New() = %v, %v", kms, err)
|
|
||||||
}
|
|
||||||
signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
|
|
||||||
if err == nil || signer != nil {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Load ssh test fixtures
|
|
||||||
b, err := os.ReadFile("testdata/ssh")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
privateKey, err := ssh.ParseRawPrivateKey(b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And add that key to the agent
|
|
||||||
err = sshagent.Add(agent.AddedKey{PrivateKey: privateKey, Comment: comment})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sshagent.add: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And test that we can find it when it's loaded
|
|
||||||
t.Run("Keys with OpenSSHAgent", func(t *testing.T) {
|
|
||||||
kms, err := New(context.Background(), apiv1.Options{})
|
|
||||||
if kms == nil || err != nil {
|
|
||||||
t.Errorf("New() = %v, %v", kms, err)
|
|
||||||
}
|
|
||||||
signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
|
|
||||||
if err != nil || signer == nil {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewFromAgent(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
opts apiv1.Options
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
sshagentstarter startTestAgentFunc
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok OpenSSHAgent", args{context.Background(), apiv1.Options{}}, startTestOpenSSHAgent, false},
|
|
||||||
{"ok KeyringAgent", args{context.Background(), apiv1.Options{}}, startTestKeyringAgent, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NewFromAgent(tt.args.ctx, tt.args.opts, tt.sshagentstarter(t))
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewFromAgent() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got == nil {
|
|
||||||
t.Errorf("NewFromAgent() = %v", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHAgentKMS_Close(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SSHAgentKMS{}
|
|
||||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SSHAgentKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHAgentKMS_CreateSigner(t *testing.T) {
|
|
||||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlock, err := pemutil.Serialize(pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and decode file using standard packages
|
|
||||||
b, err := os.ReadFile("testdata/priv.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(b)
|
|
||||||
block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pk2, err := x509.ParseECPrivateKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a public PEM
|
|
||||||
b, err = x509.MarshalPKIXPublicKey(pk.Public())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
pub := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "PUBLIC KEY",
|
|
||||||
Bytes: b,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Load ssh test fixtures
|
|
||||||
sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
_, comment, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyStr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
b, err = os.ReadFile("testdata/ssh")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
privateKey, err := ssh.ParseRawPrivateKey(b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
sshPrivateKey, err := ssh.NewSignerFromKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
wrappedSSHPrivateKey := NewWrappedSignerFromSSHSigner(sshPrivateKey)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateSignerRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.Signer
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
|
|
||||||
{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
|
|
||||||
{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
|
|
||||||
{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
|
|
||||||
{"sshagent", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment}}, wrappedSSHPrivateKey, false},
|
|
||||||
{"sshagent Nonexistant", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:Nonexistant"}}, nil, true},
|
|
||||||
{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
||||||
{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
|
|
||||||
{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
|
|
||||||
{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
|
|
||||||
{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
|
|
||||||
{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
|
|
||||||
}
|
|
||||||
starters := []struct {
|
|
||||||
name string
|
|
||||||
starter startTestAgentFunc
|
|
||||||
}{
|
|
||||||
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
|
|
||||||
{"startTestKeyringAgent", startTestKeyringAgent},
|
|
||||||
}
|
|
||||||
for _, starter := range starters {
|
|
||||||
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: privateKey, Comment: comment}))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.CreateSigner(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// nolint:gocritic
|
|
||||||
switch s := got.(type) {
|
|
||||||
case *WrappedSSHSigner:
|
|
||||||
gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
|
|
||||||
wantPkS := string(sshPubKeyStr)
|
|
||||||
if !reflect.DeepEqual(gotPkS, wantPkS) {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", gotPkS, wantPkS)
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", gotPkS, wantPkS)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", got, tt.want)
|
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
func restoreGenerateKey() func() {
|
|
||||||
oldGenerateKey := generateKey
|
|
||||||
return func() {
|
|
||||||
generateKey = oldGenerateKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
func TestSSHAgentKMS_CreateKey(t *testing.T) {
|
|
||||||
fn := restoreGenerateKey()
|
|
||||||
defer fn()
|
|
||||||
|
|
||||||
p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.CreateKeyRequest
|
|
||||||
}
|
|
||||||
type params struct {
|
|
||||||
kty string
|
|
||||||
crv string
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
generateKey func() (interface{}, interface{}, error)
|
|
||||||
want *apiv1.CreateKeyResponse
|
|
||||||
wantParams params
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
|
|
||||||
{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
|
|
||||||
{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
|
|
||||||
{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
|
|
||||||
return rsa2048.Public(), rsa2048, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
|
|
||||||
{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
|
|
||||||
return edpub, edpriv, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
|
|
||||||
{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
|
|
||||||
{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
|
|
||||||
return p256.Public(), p256, nil
|
|
||||||
}, nil, params{}, true},
|
|
||||||
{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return nil, nil, fmt.Errorf("an error")
|
|
||||||
}, nil, params{"EC", "P-256", 0}, true},
|
|
||||||
{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
|
|
||||||
return 1, 2, nil
|
|
||||||
}, nil, params{"EC", "P-256", 0}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
k := &SSHAgentKMS{}
|
|
||||||
generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
|
|
||||||
if tt.wantParams.kty != kty {
|
|
||||||
t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
|
|
||||||
}
|
|
||||||
if tt.wantParams.crv != crv {
|
|
||||||
t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
|
|
||||||
}
|
|
||||||
if tt.wantParams.size != size {
|
|
||||||
t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
|
|
||||||
}
|
|
||||||
return tt.generateKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := k.CreateKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SSHAgentKMS.CreateKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
|
|
||||||
b, err := os.ReadFile("testdata/pub.pem")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode(b)
|
|
||||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load ssh test fixtures
|
|
||||||
b, err = os.ReadFile("testdata/ssh.pub")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
sshPubKey, comment, _, _, err := ssh.ParseAuthorizedKey(b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
b, err = os.ReadFile("testdata/ssh")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
// crypto.PrivateKey
|
|
||||||
sshPrivateKey, err := ssh.ParseRawPrivateKey(b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
req *apiv1.GetPublicKeyRequest
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want crypto.PublicKey
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
|
|
||||||
{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
|
|
||||||
{"sshagent", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:" + comment}}, sshPubKey, false},
|
|
||||||
{"sshagent Nonexistant", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:Nonexistant"}}, nil, true},
|
|
||||||
{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
|
|
||||||
{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
|
|
||||||
}
|
|
||||||
starters := []struct {
|
|
||||||
name string
|
|
||||||
starter startTestAgentFunc
|
|
||||||
}{
|
|
||||||
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
|
|
||||||
{"startTestKeyringAgent", startTestKeyringAgent},
|
|
||||||
}
|
|
||||||
for _, starter := range starters {
|
|
||||||
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: sshPrivateKey, Comment: comment}))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
|
|
||||||
got, err := k.GetPublicKey(tt.args.req)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// nolint:gocritic
|
|
||||||
switch tt.want.(type) {
|
|
||||||
case ssh.PublicKey:
|
|
||||||
// If we want a ssh.PublicKey, protote got to a
|
|
||||||
got, err = ssh.NewPublicKey(got)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("SSHAgentKMS.GetPublicKey() = %T, want %T", got, tt.want)
|
|
||||||
t.Errorf("SSHAgentKMS.GetPublicKey() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHAgentKMS_CreateKey(t *testing.T) {
|
|
||||||
starters := []struct {
|
|
||||||
name string
|
|
||||||
starter startTestAgentFunc
|
|
||||||
}{
|
|
||||||
{"startTestOpenSSHAgent", startTestOpenSSHAgent},
|
|
||||||
{"startTestKeyringAgent", startTestKeyringAgent},
|
|
||||||
}
|
|
||||||
for _, starter := range starters {
|
|
||||||
k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
t.Run(starter.name+"/CreateKey", func(t *testing.T) {
|
|
||||||
got, err := k.CreateKey(&apiv1.CreateKeyRequest{
|
|
||||||
Name: "sshagentkms:0",
|
|
||||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
||||||
})
|
|
||||||
if got != nil {
|
|
||||||
t.Error("SSHAgentKMS.CreateKey() shoudn't return a value")
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
t.Error("SSHAgentKMS.CreateKey() didn't return a value")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
11
kms/sshagentkms/testdata/cert.crt
vendored
11
kms/sshagentkms/testdata/cert.crt
vendored
|
@ -1,11 +0,0 @@
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIBpzCCAU2gAwIBAgIQWaY8KIDAfak8aYljelf8eTAKBggqhkjOPQQDAjAdMRsw
|
|
||||||
GQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wHhcNMjAwMTE2MDAwNDU4WhcNMjAw
|
|
||||||
MTE3MDAwNDU4WjAdMRswGQYDVQQDExJ0ZXN0LnNtYWxsc3RlcC5jb20wWTATBgcq
|
|
||||||
hkjOPQIBBggqhkjOPQMBBwNCAATlU8P9blFefSWuzYx2g215NJn6yHW95PXeFqQ9
|
|
||||||
kX1jNo1VmC6Oord3We37iM8QJT4QP9ZDUaAVmJUZSjd+W8H/o28wbTAOBgNVHQ8B
|
|
||||||
Af8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
|
|
||||||
BBTn0wonKkm2lLRNYZrKhUukiynvqzAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3Rl
|
|
||||||
cC5jb20wCgYIKoZIzj0EAwIDSAAwRQIhAJ5XqryBIY1X4fl/9l0isV69eQfA0Qo5
|
|
||||||
1mjervUcEnOWAiBsmN4frz5YVw7i4UXChVBeZLZfJOKvn5eyh2gEzoq1+w==
|
|
||||||
-----END CERTIFICATE-----
|
|
5
kms/sshagentkms/testdata/cert.key
vendored
5
kms/sshagentkms/testdata/cert.key
vendored
|
@ -1,5 +0,0 @@
|
||||||
-----BEGIN EC PRIVATE KEY-----
|
|
||||||
MHcCAQEEICB6lIrMa9fVQJtdAYS4qmdYQ1BHJsEQDx8zxL38gA8toAoGCCqGSM49
|
|
||||||
AwEHoUQDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1veT13hakPZF9YzaNVZgujqK3
|
|
||||||
d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
8
kms/sshagentkms/testdata/priv.pem
vendored
8
kms/sshagentkms/testdata/priv.pem
vendored
|
@ -1,8 +0,0 @@
|
||||||
-----BEGIN EC PRIVATE KEY-----
|
|
||||||
Proc-Type: 4,ENCRYPTED
|
|
||||||
DEK-Info: AES-256-CBC,1fcec5dfbf3327f61bfe5ab6ae8a0626
|
|
||||||
|
|
||||||
V39b/pNHMbP80TXSHLsUY6UOTCzf3KwIxvj1e7S9brNMJJc9b3UiloMBJIYBkl00
|
|
||||||
NKI8JU4jSlcerR58DqsTHIELiX6a+RJLe3/iR2/5Gru+CmmWJ68jQu872WCgh6Ms
|
|
||||||
o8TzhyGx74ETmdKn5CdtylsnKMa9heW3tBLFAbNCgKc=
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
4
kms/sshagentkms/testdata/pub.pem
vendored
4
kms/sshagentkms/testdata/pub.pem
vendored
|
@ -1,4 +0,0 @@
|
||||||
-----BEGIN PUBLIC KEY-----
|
|
||||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5VPD/W5RXn0lrs2MdoNteTSZ+sh1
|
|
||||||
veT13hakPZF9YzaNVZgujqK3d1nt+4jPECU+ED/WQ1GgFZiVGUo3flvB/w==
|
|
||||||
-----END PUBLIC KEY-----
|
|
49
kms/sshagentkms/testdata/ssh
vendored
49
kms/sshagentkms/testdata/ssh
vendored
|
@ -1,49 +0,0 @@
|
||||||
-----BEGIN OPENSSH PRIVATE KEY-----
|
|
||||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn
|
|
||||||
NhAAAAAwEAAQAAAgEAth/d7zRDbv567o46KT6YYqC/EVdDpZ8m0rzIdroJL+RHVDXNQ1pU
|
|
||||||
3lrC9IWfkyjX+YwO9jHGbraJ+CgonAkl36mtLzNC4645QGS2/WdFqRR6mQCz7v4G6nOaFN
|
|
||||||
SCeErMhg0fn4f7jdqXpd0hYozIpktRVNYcpi2RMmr8e/Kadr5EVQfbYZgdKIl1O6Ws9O3Q
|
|
||||||
1BhLGi9GipEstUTvjqxZzF7oUgWKH54j5eHNXdbFqKqnK8NNQmypNLYGDsTBQHG9zRs+o0
|
|
||||||
7C2foO9ddIO2OCarcBWZfGlY05k/ZhEmrEOONh2rSLhJwqw+EJgQeU0Poe/IqjFy7jnTRk
|
|
||||||
i+tee2elBYVvHYPSofZaBmX7i21s8eBRl/ZiFx3ip6E3M54mXvKZ7SuA2qq/YW0IeKyJ5D
|
|
||||||
SuL0+sRAyiSQ2Icsyb3YKv6LXojuJTmJ9Hg9v4+aOPxOQhvNfh3b7sIh/cmz1dq/babLyO
|
|
||||||
ORrbHKDxIJME7VPMspmddV9wJgB4Gu1eWOiR/Cuv6jqYWTfiWJDIoqZRD5nF1tFqKtZ5iA
|
|
||||||
qkflv4Kbo10tv6nTlXR6TWuPu2Z/pZpx+NN+7QxVUSlRgxb7RTVcHRvpgd0TNEXGduR8ar
|
|
||||||
WVDlNewOmf5KFroW1IX/yR1OvE5RsDixxcX7Ne+uSlq9hooy9V/Ip0ffcF/Kg0NJoPwrnI
|
|
||||||
MAAAdQrAxluqwMZboAAAAHc3NoLXJzYQAAAgEAth/d7zRDbv567o46KT6YYqC/EVdDpZ8m
|
|
||||||
0rzIdroJL+RHVDXNQ1pU3lrC9IWfkyjX+YwO9jHGbraJ+CgonAkl36mtLzNC4645QGS2/W
|
|
||||||
dFqRR6mQCz7v4G6nOaFNSCeErMhg0fn4f7jdqXpd0hYozIpktRVNYcpi2RMmr8e/Kadr5E
|
|
||||||
VQfbYZgdKIl1O6Ws9O3Q1BhLGi9GipEstUTvjqxZzF7oUgWKH54j5eHNXdbFqKqnK8NNQm
|
|
||||||
ypNLYGDsTBQHG9zRs+o07C2foO9ddIO2OCarcBWZfGlY05k/ZhEmrEOONh2rSLhJwqw+EJ
|
|
||||||
gQeU0Poe/IqjFy7jnTRki+tee2elBYVvHYPSofZaBmX7i21s8eBRl/ZiFx3ip6E3M54mXv
|
|
||||||
KZ7SuA2qq/YW0IeKyJ5DSuL0+sRAyiSQ2Icsyb3YKv6LXojuJTmJ9Hg9v4+aOPxOQhvNfh
|
|
||||||
3b7sIh/cmz1dq/babLyOORrbHKDxIJME7VPMspmddV9wJgB4Gu1eWOiR/Cuv6jqYWTfiWJ
|
|
||||||
DIoqZRD5nF1tFqKtZ5iAqkflv4Kbo10tv6nTlXR6TWuPu2Z/pZpx+NN+7QxVUSlRgxb7RT
|
|
||||||
VcHRvpgd0TNEXGduR8arWVDlNewOmf5KFroW1IX/yR1OvE5RsDixxcX7Ne+uSlq9hooy9V
|
|
||||||
/Ip0ffcF/Kg0NJoPwrnIMAAAADAQABAAACADQ4KONYQemGT+ssnqKKzxigbIhlVAEeA/yy
|
|
||||||
omvgZZf0xTrw/jzMnr7umS2RTrLcKCjmLrgKh5HhBug/Y31x5gkeVojNEuXDY6kB97HqtX
|
|
||||||
+IXqqWGAFzlroMkWZdlFc3YzMgeiu8yrTes1Kcd+EQ6ss7l0NS7P383L/vCxvi8MURQvh6
|
|
||||||
ez2dZubjmtiSZWgI9DKMEKSeX4SFoaML9AAdjNXbdJNoATWVm0djmgXI+f2liK80nWdpTo
|
|
||||||
7NjikX4y0+L6SqpigfAiGL4FQ++PgGTTOZ62or6YWh65twLl8ge8iv8bPKxqIsQNrPIHF9
|
|
||||||
of7VaKMSgTa5fAvsJNQ1lW6exiK1szJ+g+zrkHuOjDaEWyIZi24/xy6iDaT1sdcjTGPJAo
|
|
||||||
WqgC9hlZQKjOOZJgwqu/kxgcsOGaGb2MD/E4xJVMvPsWYLQ5WGdiakQkVhclpcr3e0d8nw
|
|
||||||
xvqCqLsasCSECKJK+k3ReqtOe6GlTSzIpFiOgFAuYp+ejRkX6bJ2DRaYkjoWWza2VCpIJC
|
|
||||||
uyK7B3r1cV+g5KzvT6B+7TxVqYERisjWNvdppF87Vtx7C0p8mDzpJYpPY+yao3vEcq104+
|
|
||||||
yXuaPGEDTkTWOUB2uUS+AD9CBjkrGYFab1DBJob+L/7jNgVgWmMw1Yj9SDwXO6YBfbkhCf
|
|
||||||
Irfmf9Ne5i1+2SpFWBAAABAQCud97O9xI2bMGVGfbDFiaPTYGaGZ0qurLtHPpCX/YFkdBh
|
|
||||||
Z3LG7psJ/4JhkmMI3RFGhMxpUR9K22T3P/UmUt01PrDwDUpcw1JRPVIGs9AV3+GsAyyE6X
|
|
||||||
MzYo+8LNcxaPjh6ECXAQLcd9g0NOCbiqrKURBEuIBkxTy8jsmmeUlDsLcs8QKCsObJ2ozO
|
|
||||||
ACuFG5Z/SUeB7nhHnRUnozE8KsEWAgpys37AnJc1cQR6ALloh23L46rsWbSN5UGRgZdaUo
|
|
||||||
tklsDRun3qtYkDC8dDbW2Iy5A7GUXBRIA3mDYf4GDEUQvuu5Q/A2Dsr0hVi2wNVWd5O5M0
|
|
||||||
NVhuCHJU355wbbUUAAABAQDuet4GZQImmqfj2xAMoHUfSK0WagtzynP2fOSIRtOKQ9UXJN
|
|
||||||
J1CrSeu93dNACYjXt10X5ZCdZ9x/75ltyZHSUBbT1eQzPD4Jq23EcJ9ECCc4tJMpdNpJyv
|
|
||||||
8ixfeTCX0m6XP7nDDLgkuYuNTj/NTqIWotHt8/R8BA9FfTchZE+ekqj3TTIac3buU294mO
|
|
||||||
/0KKGHtt+GPHSD+ES+W28KETiFcz5nSD7oUQPXEbvsJg5bOWt9kY6JBGiizJSsEuLIjcva
|
|
||||||
H3UQMx6U805NjoGwIiKJyKgcmDMWVbeH87XxV6sllE8UaLUxbcOBdhmF/uJlazQsbqmF7B
|
|
||||||
CJB/X7SXredw9BAAABAQDDgRzgXsvBH72PMetQpWGswXp6UVsdHUUEyDiJXc5xjiVOxAIw
|
|
||||||
+pwaBRQ/6WMMJvhpZ/IFN+pAYEW5e0q2eGMpc1or4kf5eTukwJSF6VZf1Hhti6TfiStPCf
|
|
||||||
KSz07jUFROahMC88BOSwHuCc66emWlsZDrXS+pht1O7yU96epTM/hT/e8Bfi+ZFCJnQoQ5
|
|
||||||
dZuONhOYUT32rFKGBwPhsi6pjMB54vqrW1xFJbwj4i4dHFzA7UUa79j7ToAs2g2q8odTCR
|
|
||||||
CLUxGJ+YOkti67taOuRbzlL9wlxLGT+G2Dai9Ymbt18rmXR+2vazE0xFigYHPZb2QXeLAS
|
|
||||||
u104cC7ouX7DAAAAFnNzaC50ZXN0LnNtYWxsc3RlcC5jb20BAgME
|
|
||||||
-----END OPENSSH PRIVATE KEY-----
|
|
1
kms/sshagentkms/testdata/ssh.pub
vendored
1
kms/sshagentkms/testdata/ssh.pub
vendored
|
@ -1 +0,0 @@
|
||||||
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQC2H93vNENu/nrujjopPphioL8RV0OlnybSvMh2ugkv5EdUNc1DWlTeWsL0hZ+TKNf5jA72McZuton4KCicCSXfqa0vM0LjrjlAZLb9Z0WpFHqZALPu/gbqc5oU1IJ4SsyGDR+fh/uN2pel3SFijMimS1FU1hymLZEyavx78pp2vkRVB9thmB0oiXU7paz07dDUGEsaL0aKkSy1RO+OrFnMXuhSBYofniPl4c1d1sWoqqcrw01CbKk0tgYOxMFAcb3NGz6jTsLZ+g7110g7Y4JqtwFZl8aVjTmT9mESasQ442HatIuEnCrD4QmBB5TQ+h78iqMXLuOdNGSL6157Z6UFhW8dg9Kh9loGZfuLbWzx4FGX9mIXHeKnoTczniZe8pntK4Daqr9hbQh4rInkNK4vT6xEDKJJDYhyzJvdgq/oteiO4lOYn0eD2/j5o4/E5CG81+HdvuwiH9ybPV2r9tpsvI45GtscoPEgkwTtU8yymZ11X3AmAHga7V5Y6JH8K6/qOphZN+JYkMiiplEPmcXW0Woq1nmICqR+W/gpujXS2/qdOVdHpNa4+7Zn+lmnH4037tDFVRKVGDFvtFNVwdG+mB3RM0RcZ25HxqtZUOU17A6Z/koWuhbUhf/JHU68TlGwOLHFxfs1765KWr2GijL1X8inR99wX8qDQ0mg/Cucgw== ssh.test.smallstep.com
|
|
1
kms/uri/testdata/pin.txt
vendored
1
kms/uri/testdata/pin.txt
vendored
|
@ -1 +0,0 @@
|
||||||
trim-this-pin
|
|
148
kms/uri/uri.go
148
kms/uri/uri.go
|
@ -1,148 +0,0 @@
|
||||||
package uri
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/hex"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// URI implements a parser for a URI format based on the the PKCS #11 URI Scheme
|
|
||||||
// defined in https://tools.ietf.org/html/rfc7512
|
|
||||||
//
|
|
||||||
// These URIs will be used to define the key names in a KMS.
|
|
||||||
type URI struct {
|
|
||||||
*url.URL
|
|
||||||
Values url.Values
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new URI from a scheme and key-value pairs.
|
|
||||||
func New(scheme string, values url.Values) *URI {
|
|
||||||
return &URI{
|
|
||||||
URL: &url.URL{
|
|
||||||
Scheme: scheme,
|
|
||||||
Opaque: strings.ReplaceAll(values.Encode(), "&", ";"),
|
|
||||||
},
|
|
||||||
Values: values,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFile creates an uri for a file.
|
|
||||||
func NewFile(path string) *URI {
|
|
||||||
return &URI{
|
|
||||||
URL: &url.URL{
|
|
||||||
Scheme: "file",
|
|
||||||
Path: path,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasScheme returns true if the given uri has the given scheme, false otherwise.
|
|
||||||
func HasScheme(scheme, rawuri string) bool {
|
|
||||||
u, err := url.Parse(rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.EqualFold(u.Scheme, scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse returns the URI for the given string or an error.
|
|
||||||
func Parse(rawuri string) (*URI, error) {
|
|
||||||
u, err := url.Parse(rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error parsing %s", rawuri)
|
|
||||||
}
|
|
||||||
if u.Scheme == "" {
|
|
||||||
return nil, errors.Errorf("error parsing %s: scheme is missing", rawuri)
|
|
||||||
}
|
|
||||||
// Starting with Go 1.17 url.ParseQuery returns an error using semicolon as
|
|
||||||
// separator.
|
|
||||||
v, err := url.ParseQuery(strings.ReplaceAll(u.Opaque, ";", "&"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error parsing %s", rawuri)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &URI{
|
|
||||||
URL: u,
|
|
||||||
Values: v,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseWithScheme returns the URI for the given string only if it has the given
|
|
||||||
// scheme.
|
|
||||||
func ParseWithScheme(scheme, rawuri string) (*URI, error) {
|
|
||||||
u, err := Parse(rawuri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(u.Scheme, scheme) {
|
|
||||||
return nil, errors.Errorf("error parsing %s: scheme not expected", rawuri)
|
|
||||||
}
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the first value in the uri with the given key, it will return
|
|
||||||
// empty string if that field is not present.
|
|
||||||
func (u *URI) Get(key string) string {
|
|
||||||
v := u.Values.Get(key)
|
|
||||||
if v == "" {
|
|
||||||
v = u.URL.Query().Get(key)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBool returns true if a given key has the value "true". It returns false
|
|
||||||
// otherwise.
|
|
||||||
func (u *URI) GetBool(key string) bool {
|
|
||||||
v := u.Values.Get(key)
|
|
||||||
if v == "" {
|
|
||||||
v = u.URL.Query().Get(key)
|
|
||||||
}
|
|
||||||
return strings.EqualFold(v, "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetEncoded returns the first value in the uri with the given key, it will
|
|
||||||
// return empty nil if that field is not present or is empty. If the return
|
|
||||||
// value is hex encoded it will decode it and return it.
|
|
||||||
func (u *URI) GetEncoded(key string) []byte {
|
|
||||||
v := u.Get(key)
|
|
||||||
if v == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(v)%2 == 0 {
|
|
||||||
if b, err := hex.DecodeString(v); err == nil {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return []byte(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pin returns the pin encoded in the url. It will read the pin from the
|
|
||||||
// pin-value or the pin-source attributes.
|
|
||||||
func (u *URI) Pin() string {
|
|
||||||
if value := u.Get("pin-value"); value != "" {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
if path := u.Get("pin-source"); path != "" {
|
|
||||||
if b, err := readFile(path); err == nil {
|
|
||||||
return string(bytes.TrimRightFunc(b, unicode.IsSpace))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func readFile(path string) ([]byte, error) {
|
|
||||||
u, err := url.Parse(path)
|
|
||||||
if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" {
|
|
||||||
path = u.Path
|
|
||||||
}
|
|
||||||
b, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error reading %s", path)
|
|
||||||
}
|
|
||||||
return b, nil
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
//go:build go1.19
|
|
||||||
|
|
||||||
package uri
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
rawuri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *URI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"yubikey:slot-id=9a"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}},
|
|
||||||
}, false},
|
|
||||||
{"ok schema", args{"cloudkms:"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "cloudkms"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}},
|
|
||||||
}, false},
|
|
||||||
{"ok file", args{"file:///tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok file simple", args{"file:/tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert", OmitHost: true},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok file host", args{"file://tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"fail schema", args{"cloudkms"}, nil, true},
|
|
||||||
{"fail parse", args{"yubi%key:slot-id=9a"}, nil, true},
|
|
||||||
{"fail scheme", args{"yubikey"}, nil, true},
|
|
||||||
{"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := Parse(tt.args.rawuri)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
//go:build !go1.19
|
|
||||||
|
|
||||||
package uri
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
rawuri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *URI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"yubikey:slot-id=9a"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}},
|
|
||||||
}, false},
|
|
||||||
{"ok schema", args{"cloudkms:"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "cloudkms"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}},
|
|
||||||
}, false},
|
|
||||||
{"ok file", args{"file:///tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok file simple", args{"file:/tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok file host", args{"file://tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"fail schema", args{"cloudkms"}, nil, true},
|
|
||||||
{"fail parse", args{"yubi%key:slot-id=9a"}, nil, true},
|
|
||||||
{"fail scheme", args{"yubikey"}, nil, true},
|
|
||||||
{"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := Parse(tt.args.rawuri)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,282 +0,0 @@
|
||||||
package uri
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
scheme string
|
|
||||||
values url.Values
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *URI
|
|
||||||
}{
|
|
||||||
{"ok", args{"yubikey", url.Values{"slot-id": []string{"9a"}}}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}},
|
|
||||||
}},
|
|
||||||
{"ok multiple", args{"yubikey", url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "foo=bar;slot-id=9a"},
|
|
||||||
Values: url.Values{
|
|
||||||
"slot-id": []string{"9a"},
|
|
||||||
"foo": []string{"bar"},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := New(tt.args.scheme, tt.args.values); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewFile(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *URI
|
|
||||||
}{
|
|
||||||
{"ok", args{"/tmp/ca.crt"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.crt"},
|
|
||||||
Values: url.Values(nil),
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := NewFile(tt.args.path); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("NewFile() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasScheme(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
scheme string
|
|
||||||
rawuri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"yubikey", "yubikey:slot-id=9a"}, true},
|
|
||||||
{"ok empty", args{"yubikey", "yubikey:"}, true},
|
|
||||||
{"ok letter case", args{"awsKMS", "AWSkms:key-id=abcdefg?foo=bar"}, true},
|
|
||||||
{"fail", args{"yubikey", "awskms:key-id=abcdefg"}, false},
|
|
||||||
{"fail parse", args{"yubikey", "yubi%key:slot-id=9a"}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := HasScheme(tt.args.scheme, tt.args.rawuri); got != tt.want {
|
|
||||||
t.Errorf("HasScheme() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseWithScheme(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
scheme string
|
|
||||||
rawuri string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *URI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{"yubikey", "yubikey:slot-id=9a"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"},
|
|
||||||
Values: url.Values{"slot-id": []string{"9a"}},
|
|
||||||
}, false},
|
|
||||||
{"ok schema", args{"cloudkms", "cloudkms:"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "cloudkms"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"ok file", args{"file", "file:///tmp/ca.cert"}, &URI{
|
|
||||||
URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"},
|
|
||||||
Values: url.Values{},
|
|
||||||
}, false},
|
|
||||||
{"fail parse", args{"yubikey", "yubikey"}, nil, true},
|
|
||||||
{"fail scheme", args{"yubikey", "awskms:slot-id=9a"}, nil, true},
|
|
||||||
{"fail schema", args{"cloudkms", "cloudkms"}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := ParseWithScheme(tt.args.scheme, tt.args.rawuri)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("ParseWithScheme() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("ParseWithScheme() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURI_Get(t *testing.T) {
|
|
||||||
mustParse := func(s string) *URI {
|
|
||||||
u, err := Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
key string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri *URI
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, "9a"},
|
|
||||||
{"ok first", mustParse("yubikey:slot-id=9a;slot-id=9b"), args{"slot-id"}, "9a"},
|
|
||||||
{"ok multiple", mustParse("yubikey:slot-id=9a;foo=bar"), args{"foo"}, "bar"},
|
|
||||||
{"ok in query", mustParse("yubikey:slot-id=9a?foo=bar"), args{"foo"}, "bar"},
|
|
||||||
{"fail missing", mustParse("yubikey:slot-id=9a"), args{"foo"}, ""},
|
|
||||||
{"fail missing query", mustParse("yubikey:slot-id=9a?bar=zar"), args{"foo"}, ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.uri.Get(tt.args.key); got != tt.want {
|
|
||||||
t.Errorf("URI.Get() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURI_GetBool(t *testing.T) {
|
|
||||||
mustParse := func(s string) *URI {
|
|
||||||
u, err := Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
key string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri *URI
|
|
||||||
args args
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"true", mustParse("azurekms:name=foo;vault=bar;hsm=true"), args{"hsm"}, true},
|
|
||||||
{"TRUE", mustParse("azurekms:name=foo;vault=bar;hsm=TRUE"), args{"hsm"}, true},
|
|
||||||
{"tRUe query", mustParse("azurekms:name=foo;vault=bar?hsm=tRUe"), args{"hsm"}, true},
|
|
||||||
{"false", mustParse("azurekms:name=foo;vault=bar;hsm=false"), args{"hsm"}, false},
|
|
||||||
{"false query", mustParse("azurekms:name=foo;vault=bar?hsm=false"), args{"hsm"}, false},
|
|
||||||
{"empty", mustParse("azurekms:name=foo;vault=bar;hsm=?bar=true"), args{"hsm"}, false},
|
|
||||||
{"missing", mustParse("azurekms:name=foo;vault=bar"), args{"hsm"}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.uri.GetBool(tt.args.key); got != tt.want {
|
|
||||||
t.Errorf("URI.GetBool() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURI_GetEncoded(t *testing.T) {
|
|
||||||
mustParse := func(s string) *URI {
|
|
||||||
u, err := Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
key string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri *URI
|
|
||||||
args args
|
|
||||||
want []byte
|
|
||||||
}{
|
|
||||||
{"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, []byte{0x9a}},
|
|
||||||
{"ok first", mustParse("yubikey:slot-id=9a9b;slot-id=9b"), args{"slot-id"}, []byte{0x9a, 0x9b}},
|
|
||||||
{"ok percent", mustParse("yubikey:slot-id=9a;foo=%9a%9b%9c"), args{"foo"}, []byte{0x9a, 0x9b, 0x9c}},
|
|
||||||
{"ok in query", mustParse("yubikey:slot-id=9a?foo=9a"), args{"foo"}, []byte{0x9a}},
|
|
||||||
{"ok in query percent", mustParse("yubikey:slot-id=9a?foo=%9a"), args{"foo"}, []byte{0x9a}},
|
|
||||||
{"ok missing", mustParse("yubikey:slot-id=9a"), args{"foo"}, nil},
|
|
||||||
{"ok missing query", mustParse("yubikey:slot-id=9a?bar=zar"), args{"foo"}, nil},
|
|
||||||
{"ok no hex", mustParse("yubikey:slot-id=09a?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'a'}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := tt.uri.GetEncoded(tt.args.key)
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("URI.GetEncoded() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURI_Pin(t *testing.T) {
|
|
||||||
mustParse := func(s string) *URI {
|
|
||||||
u, err := Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri *URI
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"from value", mustParse("pkcs11:id=%72%73?pin-value=0123456789"), "0123456789"},
|
|
||||||
{"from source", mustParse("pkcs11:id=%72%73?pin-source=testdata/pin.txt"), "trim-this-pin"},
|
|
||||||
{"from missing", mustParse("pkcs11:id=%72%73"), ""},
|
|
||||||
{"from source missing", mustParse("pkcs11:id=%72%73?pin-source=testdata/foo.txt"), ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.uri.Pin(); got != tt.want {
|
|
||||||
t.Errorf("URI.Pin() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURI_String(t *testing.T) {
|
|
||||||
mustParse := func(s string) *URI {
|
|
||||||
u, err := Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri *URI
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"ok new", New("yubikey", url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}), "yubikey:foo=bar;slot-id=9a"},
|
|
||||||
{"ok parse", mustParse("yubikey:slot-id=9a;foo=bar?bar=zar"), "yubikey:slot-id=9a;foo=bar?bar=zar"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := tt.uri.String(); got != tt.want {
|
|
||||||
t.Errorf("URI.String() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,322 +0,0 @@
|
||||||
//go:build cgo
|
|
||||||
// +build cgo
|
|
||||||
|
|
||||||
package yubikey
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/hex"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-piv/piv-go/piv"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
"github.com/smallstep/certificates/kms/uri"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Scheme is the scheme used in uris.
|
|
||||||
const Scheme = "yubikey"
|
|
||||||
|
|
||||||
// YubiKey implements the KMS interface on a YubiKey.
|
|
||||||
type YubiKey struct {
|
|
||||||
yk *piv.YubiKey
|
|
||||||
pin string
|
|
||||||
managementKey [24]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// New initializes a new YubiKey.
|
|
||||||
// TODO(mariano): only one card is currently supported.
|
|
||||||
func New(ctx context.Context, opts apiv1.Options) (*YubiKey, error) {
|
|
||||||
managementKey := piv.DefaultManagementKey
|
|
||||||
|
|
||||||
if opts.URI != "" {
|
|
||||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if v := u.Pin(); v != "" {
|
|
||||||
opts.Pin = v
|
|
||||||
}
|
|
||||||
if v := u.Get("management-key"); v != "" {
|
|
||||||
opts.ManagementKey = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated way to set configuration parameters.
|
|
||||||
if opts.ManagementKey != "" {
|
|
||||||
b, err := hex.DecodeString(opts.ManagementKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error decoding managementKey")
|
|
||||||
}
|
|
||||||
if len(b) != 24 {
|
|
||||||
return nil, errors.New("invalid managementKey: length is not 24 bytes")
|
|
||||||
}
|
|
||||||
copy(managementKey[:], b[:24])
|
|
||||||
}
|
|
||||||
|
|
||||||
cards, err := piv.Cards()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(cards) == 0 {
|
|
||||||
return nil, errors.New("error detecting yubikey: try removing and reconnecting the device")
|
|
||||||
}
|
|
||||||
|
|
||||||
yk, err := piv.Open(cards[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error opening yubikey")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &YubiKey{
|
|
||||||
yk: yk,
|
|
||||||
pin: opts.Pin,
|
|
||||||
managementKey: managementKey,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.YubiKey, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
return New(ctx, opts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadCertificate implements kms.CertificateManager and loads a certificate
|
|
||||||
// from the YubiKey.
|
|
||||||
func (k *YubiKey) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {
|
|
||||||
slot, err := getSlot(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := k.yk.Certificate(slot)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error retrieving certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreCertificate implements kms.CertificateManager and stores a certificate
|
|
||||||
// in the YubiKey.
|
|
||||||
func (k *YubiKey) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
|
|
||||||
if req.Certificate == nil {
|
|
||||||
return errors.New("storeCertificateRequest 'Certificate' cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
slot, err := getSlot(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = k.yk.SetCertificate(k.managementKey, slot, req.Certificate)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "error storing certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPublicKey returns the public key present in the YubiKey signature slot.
|
|
||||||
func (k *YubiKey) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
|
||||||
slot, err := getSlot(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, err := k.getPublicKey(slot)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return pub, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateKey generates a new key in the YubiKey and returns the public key.
|
|
||||||
func (k *YubiKey) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
|
||||||
alg, err := getSignatureAlgorithm(req.SignatureAlgorithm, req.Bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
slot, name, err := getSlotAndName(req.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, err := k.yk.GenerateKey(k.managementKey, slot, piv.Key{
|
|
||||||
Algorithm: alg,
|
|
||||||
PINPolicy: piv.PINPolicyAlways,
|
|
||||||
TouchPolicy: piv.TouchPolicyNever,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error generating key")
|
|
||||||
}
|
|
||||||
return &apiv1.CreateKeyResponse{
|
|
||||||
Name: name,
|
|
||||||
PublicKey: pub,
|
|
||||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
||||||
SigningKey: name,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSigner creates a signer using the key present in the YubiKey signature
|
|
||||||
// slot.
|
|
||||||
func (k *YubiKey) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
|
||||||
slot, err := getSlot(req.SigningKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, err := k.getPublicKey(slot)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
priv, err := k.yk.PrivateKey(slot, pub, piv.KeyAuth{
|
|
||||||
PIN: k.pin,
|
|
||||||
PINPolicy: piv.PINPolicyAlways,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error retrieving private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, ok := priv.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("private key is not a crypto.Signer")
|
|
||||||
}
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close releases the connection to the YubiKey.
|
|
||||||
func (k *YubiKey) Close() error {
|
|
||||||
return errors.Wrap(k.yk.Close(), "error closing yubikey")
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPublicKey returns the public key on a slot. First it attempts to do
|
|
||||||
// attestation to get a certificate with the public key in it, if this succeeds
|
|
||||||
// means that the key was generated in the device. If not we'll try to get the
|
|
||||||
// key from a stored certificate in the same slot.
|
|
||||||
func (k *YubiKey) getPublicKey(slot piv.Slot) (crypto.PublicKey, error) {
|
|
||||||
cert, err := k.yk.Attest(slot)
|
|
||||||
if err != nil {
|
|
||||||
if cert, err = k.yk.Certificate(slot); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error retrieving public key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cert.PublicKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// signatureAlgorithmMapping is a mapping between the step signature algorithm,
|
|
||||||
// and bits for RSA keys, with yubikey ones.
|
|
||||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
|
|
||||||
apiv1.UnspecifiedSignAlgorithm: piv.AlgorithmEC256,
|
|
||||||
apiv1.SHA256WithRSA: map[int]piv.Algorithm{
|
|
||||||
0: piv.AlgorithmRSA2048,
|
|
||||||
1024: piv.AlgorithmRSA1024,
|
|
||||||
2048: piv.AlgorithmRSA2048,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSA: map[int]piv.Algorithm{
|
|
||||||
0: piv.AlgorithmRSA2048,
|
|
||||||
1024: piv.AlgorithmRSA1024,
|
|
||||||
2048: piv.AlgorithmRSA2048,
|
|
||||||
},
|
|
||||||
apiv1.SHA256WithRSAPSS: map[int]piv.Algorithm{
|
|
||||||
0: piv.AlgorithmRSA2048,
|
|
||||||
1024: piv.AlgorithmRSA1024,
|
|
||||||
2048: piv.AlgorithmRSA2048,
|
|
||||||
},
|
|
||||||
apiv1.SHA512WithRSAPSS: map[int]piv.Algorithm{
|
|
||||||
0: piv.AlgorithmRSA2048,
|
|
||||||
1024: piv.AlgorithmRSA1024,
|
|
||||||
2048: piv.AlgorithmRSA2048,
|
|
||||||
},
|
|
||||||
apiv1.ECDSAWithSHA256: piv.AlgorithmEC256,
|
|
||||||
apiv1.ECDSAWithSHA384: piv.AlgorithmEC384,
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSignatureAlgorithm(alg apiv1.SignatureAlgorithm, bits int) (piv.Algorithm, error) {
|
|
||||||
v, ok := signatureAlgorithmMapping[alg]
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.Errorf("YubiKey does not support signature algorithm '%s'", alg)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := v.(type) {
|
|
||||||
case piv.Algorithm:
|
|
||||||
return v, nil
|
|
||||||
case map[int]piv.Algorithm:
|
|
||||||
signatureAlgorithm, ok := v[bits]
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.Errorf("YubiKey does not support signature algorithm '%s' with '%d' bits", alg, bits)
|
|
||||||
}
|
|
||||||
return signatureAlgorithm, nil
|
|
||||||
default:
|
|
||||||
return 0, errors.Errorf("unexpected error: this should not happen")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var slotMapping = map[string]piv.Slot{
|
|
||||||
"9a": piv.SlotAuthentication,
|
|
||||||
"9c": piv.SlotSignature,
|
|
||||||
"9e": piv.SlotCardAuthentication,
|
|
||||||
"9d": piv.SlotKeyManagement,
|
|
||||||
"82": {Key: 0x82, Object: 0x5FC10D},
|
|
||||||
"83": {Key: 0x83, Object: 0x5FC10E},
|
|
||||||
"84": {Key: 0x84, Object: 0x5FC10F},
|
|
||||||
"85": {Key: 0x85, Object: 0x5FC110},
|
|
||||||
"86": {Key: 0x86, Object: 0x5FC111},
|
|
||||||
"87": {Key: 0x87, Object: 0x5FC112},
|
|
||||||
"88": {Key: 0x88, Object: 0x5FC113},
|
|
||||||
"89": {Key: 0x89, Object: 0x5FC114},
|
|
||||||
"8a": {Key: 0x8a, Object: 0x5FC115},
|
|
||||||
"8b": {Key: 0x8b, Object: 0x5FC116},
|
|
||||||
"8c": {Key: 0x8c, Object: 0x5FC117},
|
|
||||||
"8d": {Key: 0x8d, Object: 0x5FC118},
|
|
||||||
"8e": {Key: 0x8e, Object: 0x5FC119},
|
|
||||||
"8f": {Key: 0x8f, Object: 0x5FC11A},
|
|
||||||
"90": {Key: 0x90, Object: 0x5FC11B},
|
|
||||||
"91": {Key: 0x91, Object: 0x5FC11C},
|
|
||||||
"92": {Key: 0x92, Object: 0x5FC11D},
|
|
||||||
"93": {Key: 0x93, Object: 0x5FC11E},
|
|
||||||
"94": {Key: 0x94, Object: 0x5FC11F},
|
|
||||||
"95": {Key: 0x95, Object: 0x5FC120},
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSlot(name string) (piv.Slot, error) {
|
|
||||||
slot, _, err := getSlotAndName(name)
|
|
||||||
return slot, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSlotAndName(name string) (piv.Slot, string, error) {
|
|
||||||
if name == "" {
|
|
||||||
return piv.SlotSignature, "yubikey:slot-id=9c", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var slotID string
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
if strings.HasPrefix(name, "yubikey:") {
|
|
||||||
u, err := url.Parse(name)
|
|
||||||
if err != nil {
|
|
||||||
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s'", name)
|
|
||||||
}
|
|
||||||
v, err := url.ParseQuery(u.Opaque)
|
|
||||||
if err != nil {
|
|
||||||
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s'", name)
|
|
||||||
}
|
|
||||||
if slotID = v.Get("slot-id"); slotID == "" {
|
|
||||||
return piv.Slot{}, "", errors.Wrapf(err, "error parsing '%s': slot-id is missing", name)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
slotID = name
|
|
||||||
}
|
|
||||||
|
|
||||||
s, ok := slotMapping[slotID]
|
|
||||||
if !ok {
|
|
||||||
return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
name = "yubikey:slot-id=" + url.QueryEscape(slotID)
|
|
||||||
return s, name, nil
|
|
||||||
}
|
|
|
@ -1,20 +0,0 @@
|
||||||
//go:build !cgo
|
|
||||||
// +build !cgo
|
|
||||||
|
|
||||||
package yubikey
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
apiv1.Register(apiv1.YubiKey, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
|
||||||
name := filepath.Base(os.Args[0])
|
|
||||||
return nil, errors.Errorf("unsupported kms type 'yubikey': %s is compiled without cgo support", name)
|
|
||||||
})
|
|
||||||
}
|
|
Loading…
Reference in a new issue