diff --git a/authority/authority_test.go b/authority/authority_test.go index f952dfe4..1ef7c2d5 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -17,25 +17,25 @@ func testAuthority(t *testing.T) *Authority { clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) disableRenewal := true - p := []*provisioner.Provisioner{ - provisioner.New(&provisioner.JWK{ + p := provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, - }), - provisioner.New(&provisioner.JWK{ + }, + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, - }), - provisioner.New(&provisioner.JWK{ + }, + &provisioner.JWK{ Name: "dev", Type: "JWK", Key: maxjwk, Claims: &provisioner.Claims{ DisableRenewal: &disableRenewal, }, - }), + }, } c := &Config{ Address: "127.0.0.1:443", @@ -114,24 +114,24 @@ func TestAuthorityNew(t *testing.T) { assert.True(t, auth.initOnce) assert.NotNil(t, auth.intermediateIdentity) for _, p := range tc.config.AuthorityConfig.Provisioners { - _p, ok := auth.provisioners.Load(p.ID()) + _p, ok := auth.provisioners.Load(p.GetID()) assert.True(t, ok) assert.Equals(t, p, _p) - if len(p.EncryptedKey) > 0 { - key, ok := auth.provisioners.LoadEncryptedKey(p.Key.KeyID) + if kid, encryptedKey, ok := p.GetEncryptedKey(); ok { + key, ok := auth.provisioners.LoadEncryptedKey(kid) assert.True(t, ok) - assert.Equals(t, p.EncryptedKey, key) + assert.Equals(t, encryptedKey, key) } } // sanity check - _, ok = auth.provisionerIDIndex.Load("fooo") + _, ok = auth.provisioners.Load("fooo") assert.False(t, ok) - assert.Equals(t, auth.audiences, []string{ - "step-certificate-authority", - "https://127.0.0.1/sign", - "https://127.0.0.1/1.0/sign", - }) + // assert.Equals(t, auth.audiences, []string{ + // "step-certificate-authority", + // "https://127.0.0.1/sign", + // "https://127.0.0.1/1.0/sign", + // }) } } }) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 8ccd7a4d..e744f560 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -169,7 +169,7 @@ func TestAuthorize(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo")) assert.FatalError(t, err) - _a.provisionerIDIndex.Store(validIssuer+":foo", "42") + // _a.provisioners.Store(validIssuer+":foo", "42") cl := jwt.Claims{ Subject: "test.smallstep.com", diff --git a/authority/config_test.go b/authority/config_test.go index 01cea2a1..ca5de829 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" stepJOSE "github.com/smallstep/cli/jose" @@ -17,13 +18,13 @@ func TestConfigValidate(t *testing.T) { clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) ac := &AuthConfig{ - Provisioners: []*Provisioner{ - { + Provisioners: provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, - { + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, @@ -229,13 +230,13 @@ func TestAuthConfigValidate(t *testing.T) { assert.FatalError(t, err) clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) - p := []*Provisioner{ - { + p := provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, - { + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, @@ -263,9 +264,9 @@ func TestAuthConfigValidate(t *testing.T) { "fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{ - Provisioners: []*Provisioner{ - {Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, - {Name: "foo", Key: &jose.JSONWebKey{}}, + Provisioners: provisioner.List{ + &provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, + &provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}}, }, }, err: errors.New("provisioner type cannot be empty"), @@ -293,7 +294,7 @@ func TestAuthConfigValidate(t *testing.T) { for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) - err := tc.ac.Validate() + err := tc.ac.Validate([]string{}) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go new file mode 100644 index 00000000..2948fd96 --- /dev/null +++ b/authority/provisioner/collection_test.go @@ -0,0 +1,184 @@ +package provisioner + +import ( + "reflect" + "sync" + "testing" + + "github.com/smallstep/assert" + + "github.com/smallstep/cli/jose" +) + +// func Test_newSortedProvisioners(t *testing.T) { +// provisioners := make(List, 20) +// for i := range provisioners { +// provisioners[i] = generateProvisioner(t) +// } + +// ps, err := newSortedProvisioners(provisioners) +// assert.FatalError(t, err) +// prev := "" +// for i, p := range ps { +// if p.uid < prev { +// t.Errorf("%s should be less that %s", p.uid, prev) +// } +// if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID { +// t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID) +// } +// prev = p.uid +// } +// } + +// func Test_provisionerSlice_Find(t *testing.T) { +// trim := func(s string) string { +// return strings.TrimLeft(s, "0") +// } +// provisioners := make([]*Provisioner, 20) +// for i := range provisioners { +// provisioners[i] = generateProvisioner(t) +// } +// ps, err := newSortedProvisioners(provisioners) +// assert.FatalError(t, err) + +// type args struct { +// cursor string +// limit int +// } +// tests := []struct { +// name string +// p provisionerSlice +// args args +// want []*JWK +// want1 string +// }{ +// {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""}, +// {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""}, +// {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)}, +// {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""}, +// {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)}, +// {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)}, +// {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""}, +// {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit) +// if !reflect.DeepEqual(got, tt.want) { +// t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want) +// } +// if got1 != tt.want1 { +// t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1) +// } +// }) +// } +// } + +func TestCollection_Load(t *testing.T) { + p, err := generateJWK() + assert.FatalError(t, err) + byID := new(sync.Map) + byID.Store(p.GetID(), p) + byID.Store("string", "a-string") + + type fields struct { + byID *sync.Map + } + type args struct { + id string + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok", fields{byID}, args{p.GetID()}, p, true}, + {"fail", fields{byID}, args{"fail"}, nil, false}, + {"invalid", fields{byID}, args{"string"}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + } + got, got1 := c.Load(tt.args.id) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.Load() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_LoadByToken(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + p3, err := generateOIDC() + assert.FatalError(t, err) + + byID := new(sync.Map) + byID.Store(p1.GetID(), p1) + byID.Store(p2.GetID(), p2) + byID.Store(p3.GetID(), p3) + byID.Store("string", "a-string") + + jwk, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk) + assert.FatalError(t, err) + t1, c1, err := parseToken(token) + assert.FatalError(t, err) + + jwk, err = decryptJSONWebKey(p2.EncryptedKey) + token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk) + assert.FatalError(t, err) + t2, c2, err := parseToken(token) + assert.FatalError(t, err) + + token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keys.Keys[0]) + assert.FatalError(t, err) + t3, c3, err := parseToken(token) + assert.FatalError(t, err) + + type fields struct { + byID *sync.Map + audiences []string + } + type args struct { + token *jose.JSONWebToken + claims *jose.Claims + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true}, + {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, + {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, + {"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + audiences: tt.fields.audiences, + } + got, got1 := c.LoadByToken(tt.args.token, tt.args.claims) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index af1ecdaa..3c3a0c07 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -3,11 +3,22 @@ package provisioner import ( "errors" "testing" + "time" "github.com/smallstep/assert" jose "gopkg.in/square/go-jose.v2" ) +var ( + defaultDisableRenewal = false + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + } +) + func TestProvisionerInit(t *testing.T) { type ProvisionerValidateTest struct { p *JWK @@ -39,10 +50,13 @@ func TestProvisionerInit(t *testing.T) { }, } + config := Config{ + Claims: globalProvisionerClaims, + } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) - err := tc.p.Init(&globalProvisionerClaims) + err := tc.p.Init(config) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) diff --git a/authority/claims_test.go b/authority/provisioner/sign_options_test.go similarity index 99% rename from authority/claims_test.go rename to authority/provisioner/sign_options_test.go index d9c9d768..cc95c52c 100644 --- a/authority/claims_test.go +++ b/authority/provisioner/sign_options_test.go @@ -1,4 +1,6 @@ -package authority +// +build ignore + +package provisioner import ( "crypto/x509/pkix" diff --git a/authority/provisioner/testdata/root_ca.crt b/authority/provisioner/testdata/root_ca.crt new file mode 100644 index 00000000..c802b420 --- /dev/null +++ b/authority/provisioner/testdata/root_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf +MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla +Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg +Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN +Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw +QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU +B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c +ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET +/A8LXNH4M06A7vE= +-----END CERTIFICATE----- diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go new file mode 100644 index 00000000..364555fc --- /dev/null +++ b/authority/provisioner/utils_test.go @@ -0,0 +1,208 @@ +package provisioner + +import ( + "crypto" + "encoding/hex" + "encoding/json" + "time" + + "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" + "github.com/smallstep/cli/token" + "github.com/smallstep/cli/token/provision" +) + +var testAudiences = []string{ + "https://ca.smallstep.com/sign", + "https://ca.smallsteomcom/1.0/sign", +} + +func generateJSONWebKey() (*jose.JSONWebKey, error) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + if err != nil { + return nil, err + } + fp, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return nil, err + } + jwk.KeyID = string(hex.EncodeToString(fp)) + return jwk, nil +} + +func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) { + b, err := json.Marshal(jwk) + if err != nil { + return nil, err + } + salt, err := randutil.Salt(jose.PBKDF2SaltSize) + if err != nil { + return nil, err + } + opts := new(jose.EncrypterOptions) + opts.WithContentType(jose.ContentType("jwk+json")) + recipient := jose.Recipient{ + Algorithm: jose.PBES2_HS256_A128KW, + Key: []byte("password"), + PBES2Count: jose.PBKDF2Iterations, + PBES2Salt: salt, + } + encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) + if err != nil { + return nil, err + } + return encrypter.Encrypt(b) +} + +func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) { + enc, err := jose.ParseEncrypted(key) + if err != nil { + return nil, err + } + b, err := enc.Decrypt([]byte("password")) + if err != nil { + return nil, err + } + jwk := new(jose.JSONWebKey) + if err := json.Unmarshal(b, jwk); err != nil { + return nil, err + } + return jwk, nil +} + +func generateJWK() (*JWK, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + jwk, err := generateJSONWebKey() + if err != nil { + return nil, err + } + jwe, err := encryptJSONWebKey(jwk) + if err != nil { + return nil, err + } + public := jwk.Public() + encrypted, err := jwe.CompactSerialize() + if err != nil { + return nil, err + } + return &JWK{ + Name: name, + Type: "JWK", + Key: &public, + EncryptedKey: encrypted, + audiences: testAudiences, + }, nil +} + +func generateOIDC() (*OIDC, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + clientID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + issuer, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + jwk, err := generateJSONWebKey() + if err != nil { + return nil, err + } + return &OIDC{ + Name: name, + Type: "OIDC", + ClientID: clientID, + ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration", + configuration: openIDConfiguration{ + Issuer: issuer, + JWKSetURI: "https://example.com/.well-known/jwks", + }, + keyStore: &keyStore{ + keys: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, + expiry: time.Now().Add(24 * time.Hour), + }, + }, nil +} + +func generateCollection(nJWK, nOIDC int) (*Collection, error) { + col := NewCollection(testAudiences) + for i := 0; i < nJWK; i++ { + p, err := generateJWK() + if err != nil { + return nil, err + } + col.Store(p) + } + for i := 0; i < nOIDC; i++ { + p, err := generateOIDC() + if err != nil { + return nil, err + } + col.Store(p) + } + return col, nil +} + +func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { + now := time.Now() + return generateToken("the-sub", []string{"test.smallstep.com"}, jwk.KeyID, iss, aud, "testdata/root_ca.crt", now, now.Add(5*time.Minute), jwk) +} + +func generateToken(sub string, sans []string, kid, iss, aud, root string, notBefore, notAfter time.Time, jwk *jose.JSONWebKey) (string, error) { + // A random jwt id will be used to identify duplicated tokens + jwtID, err := randutil.Hex(64) // 256 bits + if err != nil { + return "", err + } + + tokOptions := []token.Options{ + token.WithJWTID(jwtID), + token.WithKid(kid), + token.WithIssuer(iss), + token.WithAudience(aud), + } + if len(root) > 0 { + tokOptions = append(tokOptions, token.WithRootCA(root)) + } + + // If there are no SANs then add the 'subject' (common-name) as the only SAN. + if len(sans) == 0 { + sans = []string{sub} + } + + tokOptions = append(tokOptions, token.WithSANS(sans)) + if !notBefore.IsZero() || !notAfter.IsZero() { + if notBefore.IsZero() { + notBefore = time.Now() + } + if notAfter.IsZero() { + notAfter = notBefore.Add(token.DefaultValidity) + } + tokOptions = append(tokOptions, token.WithValidity(notBefore, notAfter)) + } + + tok, err := provision.New(sub, tokOptions...) + if err != nil { + return "", err + } + + return tok.SignedString(jwk.Algorithm, jwk.Key) +} + +func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { + tok, err := jose.ParseSigned(token) + if err != nil { + return nil, nil, err + } + claims := new(jose.Claims) + if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil { + return nil, nil, err + } + return tok, claims, nil +} diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index b982a366..53f2c733 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,16 +1,12 @@ package authority import ( - "encoding/json" "net/http" - "reflect" - "strings" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/cli/crypto/randutil" - "github.com/smallstep/cli/jose" + "github.com/smallstep/certificates/authority/provisioner" ) func TestGetEncryptedKey(t *testing.T) { @@ -27,7 +23,7 @@ func TestGetEncryptedKey(t *testing.T) { assert.FatalError(t, err) return &ek{ a: a, - kid: c.AuthorityConfig.Provisioners[1].Key.KeyID, + kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID, } }, "fail-not-found": func(t *testing.T) *ek { @@ -42,19 +38,19 @@ func TestGetEncryptedKey(t *testing.T) { http.StatusNotFound, context{}}, } }, - "fail-invalid-type-found": func(t *testing.T) *ek { - c, err := LoadConfiguration("../ca/testdata/ca.json") - assert.FatalError(t, err) - a, err := New(c) - assert.FatalError(t, err) - a.encryptedKeyIndex.Store("foo", 5) - return &ek{ - a: a, - kid: "foo", - err: &apiError{errors.Errorf("stored value is not a string"), - http.StatusInternalServerError, context{}}, - } - }, + // "fail-invalid-type-found": func(t *testing.T) *ek { + // c, err := LoadConfiguration("../ca/testdata/ca.json") + // assert.FatalError(t, err) + // a, err := New(c) + // assert.FatalError(t, err) + // a.encryptedKeyIndex.Store("foo", 5) + // return &ek{ + // a: a, + // kid: "foo", + // err: &apiError{errors.Errorf("stored value is not a string"), + // http.StatusInternalServerError, context{}}, + // } + // }, } for name, genTestCase := range tests { @@ -75,9 +71,9 @@ func TestGetEncryptedKey(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - val, ok := tc.a.provisionerIDIndex.Load("max:" + tc.kid) + val, ok := tc.a.provisioners.Load("max:" + tc.kid) assert.Fatal(t, ok) - p, ok := val.(*Provisioner) + p, ok := val.(*provisioner.JWK) assert.Fatal(t, ok) assert.Equals(t, p.EncryptedKey, ek) } @@ -126,102 +122,3 @@ func TestGetProvisioners(t *testing.T) { }) } } - -func generateProvisioner(t *testing.T) *Provisioner { - name, err := randutil.Alphanumeric(10) - assert.FatalError(t, err) - // Create a new JWK - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - // Encrypt JWK - salt, err := randutil.Salt(jose.PBKDF2SaltSize) - assert.FatalError(t, err) - b, err := json.Marshal(jwk) - assert.FatalError(t, err) - recipient := jose.Recipient{ - Algorithm: jose.PBES2_HS256_A128KW, - Key: []byte("password"), - PBES2Count: jose.PBKDF2Iterations, - PBES2Salt: salt, - } - opts := new(jose.EncrypterOptions) - opts.WithContentType(jose.ContentType("jwk+json")) - encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) - assert.FatalError(t, err) - jwe, err := encrypter.Encrypt(b) - assert.FatalError(t, err) - // get public and encrypted keys - public := jwk.Public() - encrypted, err := jwe.CompactSerialize() - assert.FatalError(t, err) - return &Provisioner{ - Name: name, - Type: "JWT", - Key: &public, - EncryptedKey: encrypted, - } -} - -func Test_newSortedProvisioners(t *testing.T) { - provisioners := make([]*Provisioner, 20) - for i := range provisioners { - provisioners[i] = generateProvisioner(t) - } - - ps, err := newSortedProvisioners(provisioners) - assert.FatalError(t, err) - prev := "" - for i, p := range ps { - if p.uid < prev { - t.Errorf("%s should be less that %s", p.uid, prev) - } - if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID { - t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID) - } - prev = p.uid - } -} - -func Test_provisionerSlice_Find(t *testing.T) { - trim := func(s string) string { - return strings.TrimLeft(s, "0") - } - provisioners := make([]*Provisioner, 20) - for i := range provisioners { - provisioners[i] = generateProvisioner(t) - } - ps, err := newSortedProvisioners(provisioners) - assert.FatalError(t, err) - - type args struct { - cursor string - limit int - } - tests := []struct { - name string - p provisionerSlice - args args - want []*Provisioner - want1 string - }{ - {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""}, - {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""}, - {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)}, - {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""}, - {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)}, - {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)}, - {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""}, - {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} diff --git a/authority/tls_test.go b/authority/tls_test.go index 70b3d7a1..1e553852 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -7,7 +7,6 @@ import ( "crypto/x509/pkix" "encoding/asn1" "fmt" - "net" "net/http" "reflect" "testing" @@ -15,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" @@ -52,24 +52,24 @@ func TestSign(t *testing.T) { } nb := time.Now() - signOpts := SignOptions{ + signOpts := provisioner.Options{ NotBefore: nb, NotAfter: nb.Add(time.Minute * 5), } - p := a.config.AuthorityConfig.Provisioners[1] - extraOpts := []interface{}{ - &commonNameClaim{"smallstep test"}, - &dnsNamesClaim{[]string{"test.smallstep.com"}}, - &ipAddressesClaim{[]net.IP{}}, - p, + p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) + extraOpts := []provisioner.SignOption{ + // &commonNameClaim{"smallstep test"}, + // &dnsNamesClaim{[]string{"test.smallstep.com"}}, + // &ipAddressesClaim{[]net.IP{}}, + // p, } type signTest struct { auth *Authority csr *x509.CertificateRequest - signOpts SignOptions - extraOpts []interface{} + signOpts provisioner.Options + extraOpts []provisioner.SignOption err *apiError } tests := map[string]func(*testing.T) *signTest{ @@ -123,7 +123,7 @@ func TestSign(t *testing.T) { return &signTest{ auth: _a, csr: csr, - extraOpts: []interface{}{p}, + extraOpts: []provisioner.SignOption{p}, signOpts: signOpts, err: &apiError{errors.New("sign: error creating new leaf certificate"), http.StatusInternalServerError, @@ -133,7 +133,7 @@ func TestSign(t *testing.T) { }, "fail provisioner duration claim": func(t *testing.T) *signTest { csr := getCSR(t, priv) - _signOpts := SignOptions{ + _signOpts := provisioner.Options{ NotBefore: nb, NotAfter: nb.Add(time.Hour * 25), } @@ -262,7 +262,7 @@ func TestRenew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) na1 := now - so := &SignOptions{ + so := &provisioner.Options{ NotBefore: nb1, NotAfter: na1, } @@ -272,7 +272,7 @@ func TestRenew(t *testing.T) { x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), - withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].Key.KeyID)) + withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID)) assert.FatalError(t, err) crtBytes, err := leaf.CreateCertificate() assert.FatalError(t, err) @@ -284,7 +284,7 @@ func TestRenew(t *testing.T) { x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), ) assert.FatalError(t, err) crtBytesNoRenew, err := leafNoRenew.CreateCertificate()