Implement the kms.Decrypter with PKCS#11

This interface allows the use of SCEP with PKCS#11 modules.
This commit is contained in:
Mariano Cano 2021-12-16 18:30:09 -08:00
parent ab44fbfb3f
commit 5a32401d23
2 changed files with 104 additions and 2 deletions

View file

@ -7,6 +7,7 @@ import (
"context"
"crypto"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"fmt"
@ -142,8 +143,7 @@ func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
}, nil
}
// CreateSigner creates a signer using the key present in the PKCS#11 MODULE signature
// slot.
// CreateSigner creates a signer using a key present in the PKCS#11 module.
func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
if req.SigningKey == "" {
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
@ -157,6 +157,27 @@ func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er
return signer, nil
}
// CreateDecrypter creates a decrypter using a key present in the PKCS#11
// module.
func (k *PKCS11) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
if req.DecryptionKey == "" {
return nil, errors.New("createDecrypterRequest 'decriptionKey' cannot be empty")
}
signer, err := findSigner(k.p11, req.DecryptionKey)
if err != nil {
return nil, errors.Wrap(err, "createDecrypterRequest failed")
}
// Only RSA keys will implement the Decrypter interface.
if _, ok := signer.Public().(*rsa.PublicKey); ok {
if dec, ok := signer.(crypto.Decrypter); ok {
return dec, nil
}
}
return nil, errors.New("createDecrypterRequest failed: signer does not implement crypto.Decrypter")
}
// LoadCertificate implements kms.CertificateManager and loads a certificate
// from the YubiKey.
func (k *PKCS11) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {

View file

@ -4,6 +4,7 @@
package pkcs11
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
@ -491,6 +492,86 @@ func TestPKCS11_CreateSigner(t *testing.T) {
}
}
func TestPKCS11_CreateDecrypter(t *testing.T) {
k := setupPKCS11(t)
data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")
type args struct {
req *apiv1.CreateDecrypterRequest
}
tests := []struct {
name string
args args
wantErr bool
}{
{"RSA", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7371;object=rsa-key",
}}, false},
{"RSA PSS", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7372;object=rsa-pss-key",
}}, false},
{"ECDSA P256", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7373;object=ecdsa-p256-key",
}}, true},
{"ECDSA P384", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7374;object=ecdsa-p384-key",
}}, true},
{"ECDSA P521", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:id=7375;object=ecdsa-p521-key",
}}, true},
{"fail DecryptionKey", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "",
}}, true},
{"fail uri", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "https:id=7375;object=ecdsa-p521-key",
}}, true},
{"fail FindKeyPair", args{&apiv1.CreateDecrypterRequest{
DecryptionKey: "pkcs11:foo=bar",
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := k.CreateDecrypter(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("PKCS11.CreateDecrypter() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
pub := got.Public().(*rsa.PublicKey)
// PKCS#1 v1.5
enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data)
if err != nil {
t.Errorf("rsa.EncryptPKCS1v15() error = %v", err)
return
}
dec, err := got.Decrypt(rand.Reader, enc, nil)
if err != nil {
t.Errorf("PKCS1v15.Decrypt() error = %v", err)
} else if !bytes.Equal(dec, data) {
t.Errorf("PKCS1v15.Decrypt() failed got = %s, want = %s", dec, data)
}
// RSA-OAEP
enc, err = rsa.EncryptOAEP(crypto.SHA256.New(), rand.Reader, pub, data, []byte("label"))
if err != nil {
t.Errorf("rsa.EncryptOAEP() error = %v", err)
return
}
dec, err = got.Decrypt(rand.Reader, enc, &rsa.OAEPOptions{
Hash: crypto.SHA256,
Label: []byte("label"),
})
if err != nil {
t.Errorf("RSA-OAEP.Decrypt() error = %v", err)
} else if !bytes.Equal(dec, data) {
t.Errorf("RSA-OAEP.Decrypt() RSA-OAEP failed got = %s, want = %s", dec, data)
}
}
})
}
}
func TestPKCS11_LoadCertificate(t *testing.T) {
k := setupPKCS11(t)