forked from TrueCloudLab/certificates
419 lines
11 KiB
Go
419 lines
11 KiB
Go
package provisioner
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/smallstep/assert"
|
|
"go.step.sm/crypto/jose"
|
|
)
|
|
|
|
func TestCollection_Load(t *testing.T) {
|
|
p, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
byID := new(sync.Map)
|
|
byID.Store(p.GetID(), p)
|
|
byID.Store("string", "a-string")
|
|
|
|
type fields struct {
|
|
byID *sync.Map
|
|
}
|
|
type args struct {
|
|
id string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want Interface
|
|
want1 bool
|
|
}{
|
|
{"ok", fields{byID}, args{p.GetID()}, p, true},
|
|
{"fail", fields{byID}, args{"fail"}, nil, false},
|
|
{"invalid", fields{byID}, args{"string"}, nil, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Collection{
|
|
byID: tt.fields.byID,
|
|
}
|
|
got, got1 := c.Load(tt.args.id)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollection_LoadByToken(t *testing.T) {
|
|
p1, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
p2, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
p3, err := generateOIDC()
|
|
assert.FatalError(t, err)
|
|
p4, err := generateK8sSA(nil)
|
|
assert.FatalError(t, err)
|
|
|
|
byID := new(sync.Map)
|
|
byID.Store(p1.GetID(), p1)
|
|
byID.Store(p2.GetID(), p2)
|
|
byID.Store(p3.GetID(), p3)
|
|
byID.Store(p4.GetID(), p4)
|
|
byID.Store("string", "a-string")
|
|
|
|
byID2 := new(sync.Map)
|
|
byID2.Store(p1.GetID(), p1)
|
|
byID2.Store(p2.GetID(), p2)
|
|
byID2.Store(p3.GetID(), p3)
|
|
|
|
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
|
assert.FatalError(t, err)
|
|
token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk)
|
|
assert.FatalError(t, err)
|
|
t1, c1, err := parseToken(token)
|
|
assert.FatalError(t, err)
|
|
|
|
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
|
assert.FatalError(t, err)
|
|
token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk)
|
|
assert.FatalError(t, err)
|
|
t2, c2, err := parseToken(token)
|
|
assert.FatalError(t, err)
|
|
|
|
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
|
assert.FatalError(t, err)
|
|
t3, c3, err := parseToken(token)
|
|
assert.FatalError(t, err)
|
|
|
|
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
|
assert.FatalError(t, err)
|
|
t4, c4, err := parseToken(token)
|
|
assert.FatalError(t, err)
|
|
|
|
jwk, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
assert.FatalError(t, err)
|
|
token, err = generateK8sSAToken(jwk, nil)
|
|
assert.FatalError(t, err)
|
|
t5, c5, err := parseToken(token)
|
|
assert.FatalError(t, err)
|
|
|
|
type fields struct {
|
|
byID *sync.Map
|
|
audiences Audiences
|
|
}
|
|
type args struct {
|
|
token *jose.JSONWebToken
|
|
claims *jose.Claims
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want Interface
|
|
want1 bool
|
|
}{
|
|
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
|
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
|
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
|
{"ok4", fields{byID, testAudiences}, args{t5, c5}, p4, true},
|
|
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
|
|
{"fail", fields{byID, Audiences{Sign: []string{"https://foo"}}}, args{t1, c1}, nil, false},
|
|
{"fail-no-k8sSa-provisioner", fields{byID2, testAudiences}, args{t5, c5}, nil, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Collection{
|
|
byID: tt.fields.byID,
|
|
byTokenID: tt.fields.byID,
|
|
audiences: tt.fields.audiences,
|
|
}
|
|
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollection_LoadByCertificate(t *testing.T) {
|
|
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
|
|
e := Extension{
|
|
Type: typ, Name: name, CredentialID: credentialID,
|
|
}
|
|
ext, err := e.ToExtension()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return ext
|
|
}
|
|
|
|
p1, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
p2, err := generateOIDC()
|
|
assert.FatalError(t, err)
|
|
p3, err := generateACME()
|
|
assert.FatalError(t, err)
|
|
|
|
byName := new(sync.Map)
|
|
byName.Store(p1.GetName(), p1)
|
|
byName.Store(p2.GetName(), p2)
|
|
byName.Store(p3.GetName(), p3)
|
|
|
|
ok1Cert := &x509.Certificate{
|
|
Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
|
|
}
|
|
ok2Cert := &x509.Certificate{
|
|
Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
|
|
}
|
|
ok3Cert := &x509.Certificate{
|
|
Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
|
|
}
|
|
notFoundCert := &x509.Certificate{
|
|
Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
|
|
}
|
|
badCert := &x509.Certificate{
|
|
Extensions: []pkix.Extension{
|
|
{Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
|
},
|
|
}
|
|
|
|
type fields struct {
|
|
byName *sync.Map
|
|
audiences Audiences
|
|
}
|
|
type args struct {
|
|
cert *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want Interface
|
|
want1 bool
|
|
}{
|
|
{"ok1", fields{byName, testAudiences}, args{ok1Cert}, p1, true},
|
|
{"ok2", fields{byName, testAudiences}, args{ok2Cert}, p2, true},
|
|
{"ok3", fields{byName, testAudiences}, args{ok3Cert}, p3, true},
|
|
{"noExtension", fields{byName, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
|
{"notFound", fields{byName, testAudiences}, args{notFoundCert}, nil, false},
|
|
{"badCert", fields{byName, testAudiences}, args{badCert}, nil, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Collection{
|
|
byName: tt.fields.byName,
|
|
audiences: tt.fields.audiences,
|
|
}
|
|
got, got1 := c.LoadByCertificate(tt.args.cert)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
|
c := NewCollection(testAudiences)
|
|
p1, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
assert.FatalError(t, c.Store(p1))
|
|
p2, err := generateOIDC()
|
|
assert.FatalError(t, err)
|
|
assert.FatalError(t, c.Store(p2))
|
|
|
|
// Add oidc in byKey.
|
|
// It should not happen.
|
|
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
|
|
c.byKey.Store(p2KeyID, p2)
|
|
|
|
type args struct {
|
|
keyID string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
want1 bool
|
|
}{
|
|
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
|
|
{"oidc", args{p2KeyID}, "", false},
|
|
{"notFound", args{"not-found"}, "", false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
|
|
if got != tt.want {
|
|
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollection_Store(t *testing.T) {
|
|
c := NewCollection(testAudiences)
|
|
p1, err := generateJWK()
|
|
assert.FatalError(t, err)
|
|
p2, err := generateOIDC()
|
|
assert.FatalError(t, err)
|
|
|
|
type args struct {
|
|
p Interface
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok1", args{p1}, false},
|
|
{"ok2", args{p2}, false},
|
|
{"fail1", args{p1}, true},
|
|
{"fail2", args{p2}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
|
|
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollection_Find(t *testing.T) {
|
|
c, err := generateCollection(10, 10)
|
|
assert.FatalError(t, err)
|
|
|
|
trim := func(s string) string {
|
|
return strings.TrimLeft(s, "0")
|
|
}
|
|
toList := func(ps provisionerSlice) List {
|
|
l := List{}
|
|
for _, p := range ps {
|
|
l = append(l, p.provisioner)
|
|
}
|
|
return l
|
|
}
|
|
|
|
type args struct {
|
|
cursor string
|
|
limit int
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want List
|
|
want1 string
|
|
}{
|
|
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
|
|
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
|
|
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
|
|
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
|
|
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
|
|
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
|
|
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
|
|
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_matchesAudience(t *testing.T) {
|
|
type matchesTest struct {
|
|
a, b []string
|
|
exp bool
|
|
}
|
|
tests := map[string]matchesTest{
|
|
"false arg1 empty": {
|
|
a: []string{},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
exp: false,
|
|
},
|
|
"false arg2 empty": {
|
|
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
b: []string{},
|
|
exp: false,
|
|
},
|
|
"false arg1,arg2 empty": {
|
|
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
b: []string{"step-gateway", "step-cli"},
|
|
exp: false,
|
|
},
|
|
"false": {
|
|
a: []string{"step-gateway", "step-cli"},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
exp: false,
|
|
},
|
|
"true": {
|
|
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
exp: true,
|
|
},
|
|
"true,portsA": {
|
|
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
exp: true,
|
|
},
|
|
"true,portsB": {
|
|
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
|
exp: true,
|
|
},
|
|
"true,portsAB": {
|
|
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
|
exp: true,
|
|
},
|
|
}
|
|
for name, tc := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_stripPort(t *testing.T) {
|
|
type args struct {
|
|
rawurl string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
}{
|
|
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
|
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
|
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := stripPort(tt.args.rawurl); got != tt.want {
|
|
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|