forked from TrueCloudLab/certificates
253 lines
6.8 KiB
Go
253 lines
6.8 KiB
Go
package authority
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"errors"
|
|
"net/http"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/smallstep/assert"
|
|
"github.com/smallstep/certificates/api/render"
|
|
"github.com/smallstep/certificates/authority/admin"
|
|
"github.com/smallstep/certificates/authority/provisioner"
|
|
"github.com/smallstep/certificates/db"
|
|
"go.step.sm/crypto/jose"
|
|
"go.step.sm/crypto/keyutil"
|
|
)
|
|
|
|
func TestGetEncryptedKey(t *testing.T) {
|
|
type ek struct {
|
|
a *Authority
|
|
kid string
|
|
err error
|
|
code int
|
|
}
|
|
tests := map[string]func(t *testing.T) *ek{
|
|
"ok": func(t *testing.T) *ek {
|
|
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
|
assert.FatalError(t, err)
|
|
a, err := New(c)
|
|
assert.FatalError(t, err)
|
|
return &ek{
|
|
a: a,
|
|
kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID,
|
|
}
|
|
},
|
|
"fail-not-found": func(t *testing.T) *ek {
|
|
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
|
assert.FatalError(t, err)
|
|
a, err := New(c)
|
|
assert.FatalError(t, err)
|
|
return &ek{
|
|
a: a,
|
|
kid: "foo",
|
|
err: errors.New("encrypted key with kid foo was not found"),
|
|
code: http.StatusNotFound,
|
|
}
|
|
},
|
|
}
|
|
|
|
for name, genTestCase := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
tc := genTestCase(t)
|
|
|
|
ek, err := tc.a.GetEncryptedKey(tc.kid)
|
|
if err != nil {
|
|
if assert.NotNil(t, tc.err) {
|
|
sc, ok := render.AsStatusCodedError(err)
|
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
val, ok := tc.a.provisioners.Load("mike:" + tc.kid)
|
|
assert.Fatal(t, ok)
|
|
p, ok := val.(*provisioner.JWK)
|
|
assert.Fatal(t, ok)
|
|
assert.Equals(t, p.EncryptedKey, ek)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type mockAdminDB struct {
|
|
admin.MockDB
|
|
MGetCertificateData func(string) (*db.CertificateData, error)
|
|
}
|
|
|
|
func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) {
|
|
return c.MGetCertificateData(sn)
|
|
}
|
|
|
|
func TestGetProvisioners(t *testing.T) {
|
|
type gp struct {
|
|
a *Authority
|
|
err error
|
|
code int
|
|
}
|
|
tests := map[string]func(t *testing.T) *gp{
|
|
"ok": func(t *testing.T) *gp {
|
|
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
|
assert.FatalError(t, err)
|
|
a, err := New(c)
|
|
assert.FatalError(t, err)
|
|
return &gp{a: a}
|
|
},
|
|
}
|
|
|
|
for name, genTestCase := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
tc := genTestCase(t)
|
|
|
|
ps, next, err := tc.a.GetProvisioners("", 0)
|
|
if err != nil {
|
|
if assert.NotNil(t, tc.err) {
|
|
sc, ok := render.AsStatusCodedError(err)
|
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
assert.Equals(t, ps, tc.a.config.AuthorityConfig.Provisioners)
|
|
assert.Equals(t, "", next)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
|
|
_, priv, err := keyutil.GenerateDefaultKeyPair()
|
|
assert.FatalError(t, err)
|
|
csr := getCSR(t, priv)
|
|
|
|
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
|
|
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
|
assert.FatalError(t, err)
|
|
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
|
opts, err := a.Authorize(ctx, token)
|
|
assert.FatalError(t, err)
|
|
opts = append(opts, extraOpts...)
|
|
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
|
|
assert.FatalError(t, err)
|
|
return certs[0]
|
|
}
|
|
getProvisioner := func(a *Authority, name string) provisioner.Interface {
|
|
p, ok := a.provisioners.LoadByName(name)
|
|
if !ok {
|
|
t.Fatalf("provisioner %s does not exists", name)
|
|
}
|
|
return p
|
|
}
|
|
removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
|
for i, ext := range cert.ExtraExtensions {
|
|
if ext.Id.Equal(provisioner.StepOIDProvisioner) {
|
|
cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
|
|
a0 := testAuthority(t)
|
|
|
|
a1 := testAuthority(t)
|
|
a1.db = &db.MockAuthDB{
|
|
MUseToken: func(id, tok string) (bool, error) {
|
|
return true, nil
|
|
},
|
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
|
p, err := a1.LoadProvisionerByName("dev")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &db.CertificateData{
|
|
Provisioner: &db.ProvisionerData{
|
|
ID: p.GetID(),
|
|
Name: p.GetName(),
|
|
Type: p.GetType().String(),
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
a2 := testAuthority(t)
|
|
a2.adminDB = &mockAdminDB{
|
|
MGetCertificateData: (func(s string) (*db.CertificateData, error) {
|
|
p, err := a2.LoadProvisionerByName("dev")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &db.CertificateData{
|
|
Provisioner: &db.ProvisionerData{
|
|
ID: p.GetID(),
|
|
Name: p.GetName(),
|
|
Type: p.GetType().String(),
|
|
},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
a3 := testAuthority(t)
|
|
a3.db = &db.MockAuthDB{
|
|
MUseToken: func(id, tok string) (bool, error) {
|
|
return true, nil
|
|
},
|
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
|
return &db.CertificateData{
|
|
Provisioner: &db.ProvisionerData{
|
|
ID: "foo", Name: "foo", Type: "foo",
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
a4 := testAuthority(t)
|
|
a4.adminDB = &mockAdminDB{
|
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
|
return &db.CertificateData{
|
|
Provisioner: &db.ProvisionerData{
|
|
ID: "foo", Name: "foo", Type: "foo",
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
type args struct {
|
|
crt *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
authority *Authority
|
|
args args
|
|
want provisioner.Interface
|
|
wantErr bool
|
|
}{
|
|
{"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false},
|
|
{"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false},
|
|
{"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false},
|
|
{"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true},
|
|
{"fail from db", a3, args{sign(a3, removeExtension)}, nil, true},
|
|
{"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|