package azurekms import ( "crypto" "crypto/ecdsa" "crypto/rand" "crypto/rsa" "encoding/base64" "io" "reflect" "testing" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/golang/mock/gomock" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" ) func TestNewSigner(t *testing.T) { key, err := keyutil.GenerateDefaultSigner() if err != nil { t.Fatal(err) } pub := key.Public() jwk := createJWK(t, pub) client := mockClient(t) client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ Key: jwk, }, nil) client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ Key: jwk, }, nil) client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) type args struct { client KeyVaultClient signingKey string } tests := []struct { name string args args want crypto.Signer wantErr bool }{ {"ok", args{client, "azurekms:vault=my-vault;id=my-key"}, &Signer{ client: client, vaultBaseURL: "https://my-vault.vault.azure.net/", name: "my-key", version: "", publicKey: pub, }, false}, {"ok with version", args{client, "azurekms:id=my-key;vault=my-vault?version=my-version"}, &Signer{ client: client, vaultBaseURL: "https://my-vault.vault.azure.net/", name: "my-key", version: "my-version", publicKey: pub, }, false}, {"fail GetKey", args{client, "azurekms:id=not-found;vault=my-vault?version=my-version"}, nil, true}, {"fail vault", args{client, "azurekms:id=not-found;vault="}, nil, true}, {"fail id", args{client, "azurekms:id=;vault=my-vault?version=my-version"}, nil, true}, {"fail scheme", args{client, "kms:id=not-found;vault=my-vault?version=my-version"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewSigner(tt.args.client, tt.args.signingKey) if (err != nil) != tt.wantErr { t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewSigner() = %v, want %v", got, tt.want) } }) } } func TestSigner_Public(t *testing.T) { key, err := keyutil.GenerateDefaultSigner() if err != nil { t.Fatal(err) } pub := key.Public() type fields struct { publicKey crypto.PublicKey } tests := []struct { name string fields fields want crypto.PublicKey }{ {"ok", fields{pub}, pub}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Signer{ publicKey: tt.fields.publicKey, } if got := s.Public(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Signer.Public() = %v, want %v", got, tt.want) } }) } } func TestSigner_Sign(t *testing.T) { sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) { key, err := keyutil.GenerateSigner(kty, crv, bits) if err != nil { t.Fatal(err) } h := opts.HashFunc().New() h.Write([]byte("random-data")) sum := h.Sum(nil) var sig, resultSig []byte if priv, ok := key.(*ecdsa.PrivateKey); ok { r, s, err := ecdsa.Sign(rand.Reader, priv, sum) if err != nil { t.Fatal(err) } curveBits := priv.Params().BitSize keyBytes := curveBits / 8 if curveBits%8 > 0 { keyBytes++ } rBytes := r.Bytes() rBytesPadded := make([]byte, keyBytes) copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) sBytes := s.Bytes() sBytesPadded := make([]byte, keyBytes) copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) resultSig = append(rBytesPadded, sBytesPadded...) var b cryptobyte.Builder b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1BigInt(r) b.AddASN1BigInt(s) }) sig, err = b.Bytes() if err != nil { t.Fatal(err) } } else { sig, err = key.Sign(rand.Reader, sum, opts) if err != nil { t.Fatal(err) } resultSig = sig } return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig } p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256) p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384) p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512) rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256) rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384) rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512) rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA256, }) rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA512, }) rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA512, }) ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0) if err != nil { t.Fatal(err) } client := mockClient(t) expects := []struct { name string keyVersion string alg keyvault.JSONWebKeySignatureAlgorithm digest []byte result keyvault.KeyOperationResult err error }{ {"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ Result: &p256ResultSig, }, nil}, {"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{ Result: &p386ResultSig, }, nil}, {"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{ Result: &p521ResultSig, }, nil}, {"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{ Result: &rsaSHA256ResultSig, }, nil}, {"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{ Result: &rsaSHA384ResultSig, }, nil}, {"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{ Result: &rsaSHA512ResultSig, }, nil}, {"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{ Result: &rsaPSSSHA256ResultSig, }, nil}, {"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{ Result: &rsaPSSSHA384ResultSig, }, nil}, {"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{ Result: &rsaPSSSHA512ResultSig, }, nil}, // Errors {"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest}, {"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ Result: &rsaSHA256ResultSig, }, nil}, } for _, e := range expects { value := base64.RawURLEncoding.EncodeToString(e.digest) client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{ Algorithm: e.alg, Value: &value, }).Return(e.result, e.err) } type fields struct { client KeyVaultClient vaultBaseURL string name string version string publicKey crypto.PublicKey } type args struct { rand io.Reader digest []byte opts crypto.SignerOpts } tests := []struct { name string fields fields args args want []byte wantErr bool }{ {"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ rand.Reader, p256Digest[:], crypto.SHA256, }, p256Sig, false}, {"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{ rand.Reader, p384Digest[:], crypto.SHA384, }, p384Sig, false}, {"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{ rand.Reader, p521Digest[:], crypto.SHA512, }, p521Sig, false}, {"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ rand.Reader, rsaSHA256Digest[:], crypto.SHA256, }, rsaSHA256Sig, false}, {"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{ rand.Reader, rsaSHA384Digest[:], crypto.SHA384, }, rsaSHA384Sig, false}, {"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{ rand.Reader, rsaSHA512Digest[:], crypto.SHA512, }, rsaSHA512Sig, false}, {"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ rand.Reader, rsaPSSSHA256Digest[:], &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA256, }, }, rsaPSSSHA256Sig, false}, {"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{ rand.Reader, rsaPSSSHA384Digest[:], &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA384, }, }, rsaPSSSHA384Sig, false}, {"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{ rand.Reader, rsaPSSSHA512Digest[:], &rsa.PSSOptions{ SaltLength: 64, Hash: crypto.SHA512, }, }, rsaPSSSHA512Sig, false}, {"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ rand.Reader, rsaSHA256Digest[:], crypto.SHA256, }, nil, true}, {"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ rand.Reader, p256Digest[:], crypto.SHA256, }, nil, true}, {"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ rand.Reader, rsaPSSSHA256Digest[:], &rsa.PSSOptions{ SaltLength: 64, Hash: crypto.SHA256, }, }, nil, true}, {"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ rand.Reader, rsaSHA256Digest[:], crypto.SHA1, }, nil, true}, {"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ rand.Reader, p256Digest[:], crypto.MD5, }, nil, true}, {"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{ rand.Reader, []byte("message"), crypto.Hash(0), }, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Signer{ client: tt.fields.client, vaultBaseURL: tt.fields.vaultBaseURL, name: tt.fields.name, version: tt.fields.version, publicKey: tt.fields.publicKey, } got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) } }) } }