diff --git a/kms/pkcs11/other_test.go b/kms/pkcs11/other_test.go index 602e101c..47a9ff83 100644 --- a/kms/pkcs11/other_test.go +++ b/kms/pkcs11/other_test.go @@ -1,4 +1,4 @@ -// +build !softhsm2,!yubihsm2 +// +build !softhsm2,!yubihsm2,!nitrokey package pkcs11 @@ -9,16 +9,14 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "fmt" "io" "math/big" - "testing" "github.com/ThalesIgnite/crypto11" "github.com/pkg/errors" ) -func mustPKCS11(t *testing.T) *PKCS11 { +func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "Golang crypto" k := &PKCS11{ @@ -63,10 +61,7 @@ func (s *stubPKCS11) FindKeyPair(id, label []byte) (crypto11.Signer, error) { if id == nil && label == nil { return nil, errors.New("id and label cannot both be nil") } - i, ok := s.signerIndex[newKey(id, label, nil)] - fmt.Println(i, ok) - if ok { - + if i, ok := s.signerIndex[newKey(id, label, nil)]; ok { return s.signers[i], nil } return nil, nil diff --git a/kms/pkcs11/pkcs11.go b/kms/pkcs11/pkcs11.go index d45f8045..9ac806d8 100644 --- a/kms/pkcs11/pkcs11.go +++ b/kms/pkcs11/pkcs11.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "fmt" "math/big" + "strconv" "github.com/ThalesIgnite/crypto11" "github.com/pkg/errors" @@ -17,6 +18,9 @@ import ( "github.com/smallstep/certificates/kms/uri" ) +// Scheme is the scheme used in uris. +const Scheme = "pkcs11" + // DefaultRSASize is the number of bits of a new RSA key if not bitsize has been // specified. const DefaultRSASize = 3072 @@ -46,14 +50,22 @@ type PKCS11 struct { func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) { var config crypto11.Config if opts.URI != "" { - u, err := uri.ParseWithScheme("pkcs11", opts.URI) + u, err := uri.ParseWithScheme(Scheme, opts.URI) if err != nil { return nil, err } + + config.Pin = u.Pin() config.Path = u.Get("module-path") config.TokenLabel = u.Get("token") config.TokenSerial = u.Get("serial") - config.Pin = u.Pin() + if v := u.Get("slot-id"); v != "" { + n, err := strconv.Atoi(v) + if err != nil { + return nil, errors.Wrap(err, "kms uri 'slot-id' is not valid") + } + config.SlotNumber = &n + } } if config.Pin == "" && opts.Pin != "" { config.Pin = opts.Pin @@ -62,12 +74,16 @@ func New(ctx context.Context, opts apiv1.Options) (*PKCS11, error) { switch { case config.Path == "": return nil, errors.New("kms uri 'module-path' are required") - case config.TokenLabel == "" && config.TokenSerial == "": - return nil, errors.New("kms uri 'token' or 'serial' are required") + case config.TokenLabel == "" && config.TokenSerial == "" && config.SlotNumber == nil: + return nil, errors.New("kms uri 'token', 'serial' or 'slot-id' are required") case config.Pin == "": return nil, errors.New("kms 'pin' cannot be empty") case config.TokenLabel != "" && config.TokenSerial != "": - return nil, errors.New("kms uri 'token' or 'serial' are mutually exclusive") + return nil, errors.New("kms uri 'token' and 'serial' are mutually exclusive") + case config.TokenLabel != "" && config.SlotNumber != nil: + return nil, errors.New("kms uri 'token' and 'slot-id' are mutually exclusive") + case config.TokenSerial != "" && config.SlotNumber != nil: + return nil, errors.New("kms uri 'serial' and 'slot-id' are mutually exclusive") } p11, err := p11Configure(&config) @@ -167,6 +183,16 @@ func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return errors.Wrap(err, "storeCertificate failed") } + cert, err := k.p11.FindCertificate(id, object, nil) + if err != nil { + return errors.Wrap(err, "storeCertificate failed") + } + if cert != nil { + return errors.Wrap(apiv1.ErrAlreadyExists{ + Message: req.Name + " already exists", + }, "storeCertificate failed") + } + if err := k.p11.ImportCertificateWithLabel(id, object, req.Certificate); err != nil { return errors.Wrap(err, "storeCertificate failed") } @@ -218,7 +244,7 @@ func toByte(s string) []byte { } func parseObject(rawuri string) ([]byte, []byte, error) { - u, err := uri.ParseWithScheme("pkcs11", rawuri) + u, err := uri.ParseWithScheme(Scheme, rawuri) if err != nil { return nil, nil, err } @@ -290,7 +316,7 @@ func findSigner(ctx P11, rawuri string) (crypto11.Signer, error) { } func findCertificate(ctx P11, rawuri string) (*x509.Certificate, error) { - u, err := uri.ParseWithScheme("pkcs11", rawuri) + u, err := uri.ParseWithScheme(Scheme, rawuri) if err != nil { return nil, err } diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 3c4d0bc3..e69b5598 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -12,14 +12,30 @@ import ( "crypto/x509" "math/big" "reflect" + "strings" "testing" + "github.com/ThalesIgnite/crypto11" + "github.com/pkg/errors" "github.com/smallstep/certificates/kms/apiv1" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" ) func TestNew(t *testing.T) { + tmp := p11Configure + t.Cleanup(func() { + p11Configure = tmp + }) + + k := mustPKCS11(t) + p11Configure = func(config *crypto11.Config) (P11, error) { + if strings.Index(config.Path, "fail") >= 0 { + return nil, errors.New("an error") + } + return k.p11, nil + } + type args struct { ctx context.Context opts apiv1.Options @@ -29,7 +45,71 @@ func TestNew(t *testing.T) { args args want *PKCS11 wantErr bool - }{} + }{ + {"ok", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password", + }}, k, false}, + {"ok with serial", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789?pin-value=password", + }}, k, false}, + {"ok with slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=0?pin-value=password", + }}, k, false}, + {"ok with pin", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test", + Pin: "passowrd", + }}, k, false}, + {"fail missing module", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:token=pkcs11-test", + Pin: "passowrd", + }}, nil, true}, + {"fail missing pin", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test", + }}, nil, true}, + {"fail missing token/serial/slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so", + Pin: "passowrd", + }}, nil, true}, + {"fail token+serial+slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789;slot-id=0", + Pin: "passowrd", + }}, nil, true}, + {"fail token+serial", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789", + Pin: "passowrd", + }}, nil, true}, + {"fail token+slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;slot-id=0", + Pin: "passowrd", + }}, nil, true}, + {"fail serial+slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789;slot-id=0", + Pin: "passowrd", + }}, nil, true}, + {"fail slot-id", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=x?pin-value=password", + }}, nil, true}, + {"fail scheme", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "foo:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password", + }}, nil, true}, + {"fail configure", args{context.Background(), apiv1.Options{ + Type: "pkcs11", + URI: "pkcs11:module-path=/usr/local/lib/fail.so;token=pkcs11-test?pin-value=password", + }}, nil, true}, + } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.ctx, tt.args.opts) @@ -104,8 +184,7 @@ func TestPKCS11_CreateKey(t *testing.T) { k := setupPKCS11(t) // Make sure to delete the created key - keyName := "pkcs11:id=7771;object=create-key" - k.DeleteKey(keyName) + k.DeleteKey(testObject) type args struct { req *apiv1.CreateKeyRequest @@ -118,140 +197,140 @@ func TestPKCS11_CreateKey(t *testing.T) { }{ // SoftHSM2 {"default", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &ecdsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA256WithRSA", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA256WithRSA, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA384WithRSA", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA384WithRSA, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA512WithRSA", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA512WithRSA, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA256WithRSAPSS", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA256WithRSAPSS, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA384WithRSAPSS", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA384WithRSAPSS, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA SHA512WithRSAPSS", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA512WithRSAPSS, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA 2048", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"RSA 4096", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 4096, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &rsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"ECDSA P256", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.ECDSAWithSHA256, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &ecdsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"ECDSA P384", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.ECDSAWithSHA384, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &ecdsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"ECDSA P521", args{&apiv1.CreateKeyRequest{ - Name: keyName, + Name: testObject, SignatureAlgorithm: apiv1.ECDSAWithSHA512, }}, &apiv1.CreateKeyResponse{ - Name: keyName, + Name: testObject, PublicKey: &ecdsa.PublicKey{}, CreateSignerRequest: apiv1.CreateSignerRequest{ - SigningKey: keyName, + SigningKey: testObject, }, }, false}, {"fail name", args{&apiv1.CreateKeyRequest{ Name: "", }}, nil, true}, {"fail bits", args{&apiv1.CreateKeyRequest{ - Name: "pkcs11:id=9999;object=rsa-create-key", + Name: "pkcs11:id=9999;object=create-key", Bits: -1, SignatureAlgorithm: apiv1.SHA256WithRSAPSS, }}, nil, true}, {"fail ed25519", args{&apiv1.CreateKeyRequest{ - Name: "pkcs11:id=9999;object=rsa-create-key", + Name: "pkcs11:id=9999;object=create-key", SignatureAlgorithm: apiv1.PureEd25519, }}, nil, true}, {"fail unknown", args{&apiv1.CreateKeyRequest{ - Name: "pkcs11:id=9999;object=rsa-create-key", + Name: "pkcs11:id=9999;object=create-key", SignatureAlgorithm: apiv1.SignatureAlgorithm(100), }}, nil, true}, {"fail uri", args{&apiv1.CreateKeyRequest{ @@ -262,6 +341,10 @@ func TestPKCS11_CreateKey(t *testing.T) { Name: "pkcs11:foo=bar", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, }}, nil, true}, + {"fail already exists", args{&apiv1.CreateKeyRequest{ + Name: "pkcs11:id=7373;object=ecdsa-p256-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -322,8 +405,11 @@ func TestPKCS11_CreateSigner(t *testing.T) { SigningKey: "pkcs11:id=7371;object=rsa-key", }}, apiv1.SHA256WithRSA, crypto.SHA256, false}, {"RSA PSS", args{&apiv1.CreateSignerRequest{ - SigningKey: "pkcs11:id=7371;object=rsa-key", - }}, apiv1.SHA256WithRSA, crypto.SHA256, false}, + SigningKey: "pkcs11:id=7372;object=rsa-pss-key", + }}, apiv1.SHA256WithRSAPSS, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: crypto.SHA256, + }, false}, {"ECDSA P256", args{&apiv1.CreateSignerRequest{ SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key", }}, apiv1.ECDSAWithSHA256, crypto.SHA256, false}, @@ -406,22 +492,31 @@ func TestPKCS11_LoadCertificate(t *testing.T) { wantErr bool }{ {"load", args{&apiv1.LoadCertificateRequest{ - Name: "pkcs11:id=7370;object=root", + Name: "pkcs11:id=7376;object=test-root", }}, getCertFn(0, 0), false}, {"load by id", args{&apiv1.LoadCertificateRequest{ - Name: "pkcs11:id=7370", + Name: "pkcs11:id=7376", }}, getCertFn(0, 0), false}, {"load by label", args{&apiv1.LoadCertificateRequest{ - Name: "pkcs11:object=root", + Name: "pkcs11:object=test-root", + }}, getCertFn(0, 0), false}, + {"load by serial", args{&apiv1.LoadCertificateRequest{ + Name: "pkcs11:serial=64", }}, getCertFn(0, 0), false}, {"fail missing", args{&apiv1.LoadCertificateRequest{ - Name: "pkcs11:id=9999;object=root", + Name: "pkcs11:id=9999;object=test-root", }}, nil, true}, {"fail name", args{&apiv1.LoadCertificateRequest{ Name: "", }}, nil, true}, {"fail uri", args{&apiv1.LoadCertificateRequest{ - Name: "pkcs11:id=xxxx;object=root", + Name: "pkcs11:id=xxxx;object=test-root", + }}, nil, true}, + {"fail scheme", args{&apiv1.LoadCertificateRequest{ + Name: "foo:id=7376;object=test-root", + }}, nil, true}, + {"fail serial", args{&apiv1.LoadCertificateRequest{ + Name: "pkcs11:serial=foo", }}, nil, true}, {"fail FindCertificate", args{&apiv1.LoadCertificateRequest{ Name: "pkcs11:foo=bar", @@ -460,6 +555,11 @@ func TestPKCS11_StoreCertificate(t *testing.T) { t.Fatalf("x509.CreateCertificate() error = %v", err) } + // Make sure to delete the created certificate + t.Cleanup(func() { + k.DeleteCertificate(testObject) + }) + type args struct { req *apiv1.StoreCertificateRequest } @@ -469,19 +569,23 @@ func TestPKCS11_StoreCertificate(t *testing.T) { wantErr bool }{ {"ok", args{&apiv1.StoreCertificateRequest{ - Name: "pkcs11:id=7771;object=root", + Name: testObject, Certificate: cert, }}, false}, + {"fail already exists", args{&apiv1.StoreCertificateRequest{ + Name: testObject, + Certificate: cert, + }}, true}, {"fail name", args{&apiv1.StoreCertificateRequest{ Name: "", Certificate: cert, }}, true}, {"fail certificate", args{&apiv1.StoreCertificateRequest{ - Name: "pkcs11:id=7771;object=root", + Name: testObject, Certificate: nil, }}, true}, {"fail uri", args{&apiv1.StoreCertificateRequest{ - Name: "http:id=7771;object=root", + Name: "http:id=7770;object=create-cert", Certificate: cert, }}, true}, {"fail ImportCertificateWithLabel", args{&apiv1.StoreCertificateRequest{ @@ -504,9 +608,101 @@ func TestPKCS11_StoreCertificate(t *testing.T) { if !reflect.DeepEqual(got, cert) { t.Errorf("PKCS11.LoadCertificate() = %v, want %v", got, cert) } - if err := k.DeleteCertificate(tt.args.req.Name); err != nil { - t.Errorf("PKCS11.DeleteCertificate() error = %v", err) - } + } + }) + } +} + +func TestPKCS11_DeleteKey(t *testing.T) { + k := setupPKCS11(t) + + type args struct { + uri string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"delete", args{testObject}, false}, + {"delete by id", args{testObjectByID}, false}, + {"delete by label", args{testObjectByLabel}, false}, + {"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false}, + {"fail name", args{""}, true}, + {"fail uri", args{"pkcs11:id=xxxx;object=missing-key"}, true}, + {"fail FindKeyPair", args{"pkcs11:foo=bar"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := k.CreateKey(&apiv1.CreateKeyRequest{ + Name: testObject, + }); err != nil { + t.Fatalf("PKCS1.CreateKey() error = %v", err) + } + if err := k.DeleteKey(tt.args.uri); (err != nil) != tt.wantErr { + t.Errorf("PKCS11.DeleteKey() error = %v, wantErr %v", err, tt.wantErr) + } + if _, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: tt.args.uri, + }); err == nil { + t.Error("PKCS11.GetPublicKey() public key found and not expected") + } + // Make sure to delete the created one. + if err := k.DeleteKey(testObject); err != nil { + t.Errorf("PKCS11.DeleteKey() error = %v", err) + } + }) + } +} + +func TestPKCS11_DeleteCertificate(t *testing.T) { + k := setupPKCS11(t) + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("ed25519.GenerateKey() error = %v", err) + } + + cert, err := generateCertificate(pub, priv) + if err != nil { + t.Fatalf("x509.CreateCertificate() error = %v", err) + } + + type args struct { + uri string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"delete", args{testObject}, false}, + {"delete by id", args{testObjectByID}, false}, + {"delete by label", args{testObjectByLabel}, false}, + {"delete missing", args{"pkcs11:id=9999;object=missing-key"}, false}, + {"fail name", args{""}, true}, + {"fail uri", args{"pkcs11:id=xxxx;object=missing-key"}, true}, + {"fail DeleteCertificate", args{"pkcs11:foo=bar"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{ + Name: testObject, + Certificate: cert, + }); err != nil { + t.Fatalf("PKCS11.StoreCertificate() error = %v", err) + } + if err := k.DeleteCertificate(tt.args.uri); (err != nil) != tt.wantErr { + t.Errorf("PKCS11.DeleteCertificate() error = %v, wantErr %v", err, tt.wantErr) + } + if _, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{ + Name: tt.args.uri, + }); err == nil { + t.Error("PKCS11.LoadCertificate() certificate found and not expected") + } + // Make sure to delete the created one. + if err := k.DeleteCertificate(testObject); err != nil { + t.Errorf("PKCS11.DeleteCertificate() error = %v", err) } }) } diff --git a/kms/pkcs11/setup_test.go b/kms/pkcs11/setup_test.go index b912f79c..c9ff9311 100644 --- a/kms/pkcs11/setup_test.go +++ b/kms/pkcs11/setup_test.go @@ -8,17 +8,18 @@ import ( "crypto/x509" "crypto/x509/pkix" "math/big" - "testing" "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/kms/apiv1" ) var ( - testModule = "" - testKeys = []struct { + testModule = "" + testObject = "pkcs11:id=7370;object=test-name" + testObjectByID = "pkcs11:id=7370" + testObjectByLabel = "pkcs11:object=test-name" + testKeys = []struct { Name string SignatureAlgorithm apiv1.SignatureAlgorithm Bits int @@ -35,10 +36,19 @@ var ( Key string Certificates []*x509.Certificate }{ - {"pkcs11:id=7370;object=root", "pkcs11:id=7373;object=ecdsa-p256-key", nil}, + {"pkcs11:id=7376;object=test-root", "pkcs11:id=7373;object=ecdsa-p256-key", nil}, } ) +type TBTesting interface { + Helper() + Cleanup(f func()) + Log(args ...interface{}) + Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) + Skipf(format string, args ...interface{}) +} + func generateCertificate(pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) { now := time.Now() template := &x509.Certificate{ @@ -60,7 +70,7 @@ func generateCertificate(pub crypto.PublicKey, signer crypto.Signer) (*x509.Cert return x509.ParseCertificate(b) } -func setup(t *testing.T, k *PKCS11) { +func setup(t TBTesting, k *PKCS11) { t.Log("Running using", testModule) for _, tk := range testKeys { _, err := k.CreateKey(&apiv1.CreateKeyRequest{ @@ -91,15 +101,26 @@ func setup(t *testing.T, k *PKCS11) { if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{ Name: c.Name, Certificate: cert, - }); err != nil { - t.Errorf("PKCS1.StoreCertificate() error = %v", err) + }); err != nil && !errors.Is(errors.Cause(err), apiv1.ErrAlreadyExists{ + Message: c.Name + " already exists", + }) { + t.Errorf("PKCS1.StoreCertificate() error = %+v", err) continue } testCerts[i].Certificates = append(testCerts[i].Certificates, cert) } } -func teardown(t *testing.T, k *PKCS11) { +func teardown(t TBTesting, k *PKCS11) { + testObjects := []string{testObject, testObjectByID, testObjectByLabel} + for _, name := range testObjects { + if err := k.DeleteKey(name); err != nil { + t.Errorf("PKCS11.DeleteKey() error = %v", err) + } + if err := k.DeleteCertificate(name); err != nil { + t.Errorf("PKCS11.DeleteCertificate() error = %v", err) + } + } for _, tk := range testKeys { if err := k.DeleteKey(tk.Name); err != nil { t.Errorf("PKCS11.DeleteKey() error = %v", err) @@ -112,7 +133,8 @@ func teardown(t *testing.T, k *PKCS11) { } } -func setupPKCS11(t *testing.T) *PKCS11 { +func setupPKCS11(t TBTesting) *PKCS11 { + t.Helper() k := mustPKCS11(t) t.Cleanup(func() { k.Close() diff --git a/kms/pkcs11/softhsm2_test.go b/kms/pkcs11/softhsm2_test.go index 4df99b1b..379d7a11 100644 --- a/kms/pkcs11/softhsm2_test.go +++ b/kms/pkcs11/softhsm2_test.go @@ -1,11 +1,10 @@ -// +build softhsm2,!yubihsm2 +// +build softhsm2 package pkcs11 import ( "runtime" "sync" - "testing" "github.com/ThalesIgnite/crypto11" ) @@ -13,14 +12,14 @@ import ( var softHSM2Once sync.Once // mustPKCS11 configures a *PKCS11 KMS to be used with SoftHSM2. To initialize -// this tests, we should run: +// these tests, we should run: // softhsm2-util --init-token --free \ // --token pkcs11-test --label pkcs11-test \ // --so-pin password --pin password // // To delete we should run: // softhsm2-util --delete-token --token pkcs11-test -func mustPKCS11(t *testing.T) *PKCS11 { +func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "SoftHSM2" if runtime.GOARCH != "amd64" { diff --git a/kms/pkcs11/yubihsm2_test.go b/kms/pkcs11/yubihsm2_test.go index f0e7d965..7d508872 100644 --- a/kms/pkcs11/yubihsm2_test.go +++ b/kms/pkcs11/yubihsm2_test.go @@ -1,11 +1,10 @@ -// +build !softhsm2,yubihsm2 +// +build yubihsm2 package pkcs11 import ( "runtime" "sync" - "testing" "github.com/ThalesIgnite/crypto11" ) @@ -13,9 +12,9 @@ import ( var yubiHSM2Once sync.Once // mustPKCS11 configures a *PKCS11 KMS to be used with YubiHSM2. To initialize -// this tests, we should run: +// these tests, we should run: // yubihsm-connector -d -func mustPKCS11(t *testing.T) *PKCS11 { +func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "YubiHSM2" if runtime.GOARCH != "amd64" {