forked from TrueCloudLab/certificates
191 lines
5.8 KiB
Go
191 lines
5.8 KiB
Go
package awskms
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
"github.com/aws/aws-sdk-go/service/kms"
|
|
"github.com/smallstep/cli/crypto/pemutil"
|
|
)
|
|
|
|
func TestNewSigner(t *testing.T) {
|
|
okClient := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type args struct {
|
|
svc KeyManagementClient
|
|
signingKey string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *Signer
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{
|
|
service: okClient,
|
|
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
publicKey: key,
|
|
}, false},
|
|
{"fail parse", args{okClient, "awskms:key-id="}, nil, true},
|
|
{"fail preload", args{&MockClient{
|
|
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
|
{"fail preload not der", args{&MockClient{
|
|
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
return &kms.GetPublicKeyOutput{
|
|
KeyId: input.KeyId,
|
|
PublicKey: []byte(publicKey),
|
|
}, nil
|
|
},
|
|
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := NewSigner(tt.args.svc, tt.args.signingKey)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSigner_Public(t *testing.T) {
|
|
okClient := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type fields struct {
|
|
service KeyManagementClient
|
|
keyID string
|
|
publicKey crypto.PublicKey
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
want crypto.PublicKey
|
|
}{
|
|
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &Signer{
|
|
service: tt.fields.service,
|
|
keyID: tt.fields.keyID,
|
|
publicKey: tt.fields.publicKey,
|
|
}
|
|
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSigner_Sign(t *testing.T) {
|
|
okClient := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type fields struct {
|
|
service KeyManagementClient
|
|
keyID string
|
|
publicKey crypto.PublicKey
|
|
}
|
|
type args struct {
|
|
rand io.Reader
|
|
digest []byte
|
|
opts crypto.SignerOpts
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want []byte
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
|
|
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
|
|
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
|
{"fail sign", fields{&MockClient{
|
|
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &Signer{
|
|
service: tt.fields.service,
|
|
keyID: tt.fields.keyID,
|
|
publicKey: tt.fields.publicKey,
|
|
}
|
|
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_getSigningAlgorithm(t *testing.T) {
|
|
type args struct {
|
|
key crypto.PublicKey
|
|
opts crypto.SignerOpts
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
|
|
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
|
|
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
|
|
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
|
|
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
|
|
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
|
|
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
|
|
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
|
|
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
|
|
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
|
|
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
|
|
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if got != tt.want {
|
|
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|