diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 09317529..c1ba9c5f 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -22,8 +22,7 @@ func (pc *Claims) Init(global *Claims) (*Claims, error) { pc = &Claims{} } pc.globalClaims = global - err := pc.Validate() - return pc, err + return pc, pc.Validate() } // DefaultTLSCertDuration returns the default TLS cert duration for the diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 3c8d07a8..4f4c27cf 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -101,10 +101,6 @@ func (p *JWK) Authorize(token string) ([]SignOption, error) { } dnsNames, ips := x509util.SplitSANs(claims.SANs) - if err != nil { - return nil, err - } - return []SignOption{ commonNameValidator(claims.Subject), dnsNamesValidator(dnsNames), diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 3c3a0c07..b4c90ee3 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -1,12 +1,15 @@ package provisioner import ( + "crypto/x509" "errors" + "fmt" + "strings" "testing" "time" "github.com/smallstep/assert" - jose "gopkg.in/square/go-jose.v2" + "github.com/smallstep/cli/jose" ) var ( @@ -19,7 +22,32 @@ var ( } ) -func TestProvisionerInit(t *testing.T) { +func TestJWK_Getters(t *testing.T) { + p, err := generateJWK() + assert.FatalError(t, err) + if got := p.GetID(); got != p.Name+":"+p.Key.KeyID { + t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID) + } + if got := p.GetName(); got != p.Name { + t.Errorf("JWK.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeJWK { + t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK) + } + kid, key, ok := p.GetEncryptedKey() + if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false { + t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, p.Key.KeyID, p.EncryptedKey, true) + } + p.EncryptedKey = "" + kid, key, ok = p.GetEncryptedKey() + if kid != p.Key.KeyID || key != "" || ok == true { + t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, p.Key.KeyID, "", false) + } +} + +func TestJWK_Init(t *testing.T) { type ProvisionerValidateTest struct { p *JWK err error @@ -45,13 +73,14 @@ func TestProvisionerInit(t *testing.T) { }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, } }, } config := Config{ - Claims: globalProvisionerClaims, + Claims: globalProvisionerClaims, + Audiences: testAudiences, } for name, get := range tests { t.Run(name, func(t *testing.T) { @@ -67,3 +96,143 @@ func TestProvisionerInit(t *testing.T) { }) } } + +func TestJWK_Authorize(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + + key1, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + key2, err := decryptJSONWebKey(p2.EncryptedKey) + assert.FatalError(t, err) + + t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1) + assert.FatalError(t, err) + t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2) + assert.FatalError(t, err) + t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], []string{}, key1) + assert.FatalError(t, err) + + // Invalid tokens + parts := strings.Split(t1, ".") + // invalid token + failTok := "foo." + parts[1] + "." + parts[2] + // invalid claims + failClaims := parts[0] + ".foo." + parts[1] + // invalid issuer + failIss, err := generateSimpleToken("foobar", testAudiences[0], key1) + assert.FatalError(t, err) + // invalid audience + failAud, err := generateSimpleToken(p1.Name, "foobar", key1) + assert.FatalError(t, err) + // invalid signature + failSig := t1[0 : len(t1)-2] + // no subject + failSub, err := generateToken("", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, key1) + assert.FatalError(t, err) + + // Remove encrypted key for p2 + p2.EncryptedKey = "" + + type args struct { + token string + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"ok", p1, args{t1}, false}, + {"ok-no-encrypted-key", p2, args{t2}, false}, + {"ok-no-sans", p1, args{t3}, false}, + {"fail-token", p1, args{failTok}, true}, + {"fail-claims", p1, args{failClaims}, true}, + {"fail-issuer", p1, args{failIss}, true}, + {"fail-audience", p1, args{failAud}, true}, + {"fail-signature", p1, args{failSig}, true}, + {"fail-subject", p1, args{failSub}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.prov.Authorize(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else { + assert.NotNil(t, got) + assert.Len(t, 6, got) + } + }) + } +} + +func TestJWK_AuthorizeRenewal(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + + fmt.Printf("%#v\n", *p1.Claims.DisableRenewal) + // disable renewal + disable := true + p2.Claims = &Claims{ + globalClaims: &globalProvisionerClaims, + DisableRenewal: &disable, + } + + fmt.Printf("%#v\n", *p1.Claims.DisableRenewal) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"ok", p1, args{nil}, false}, + {"fail", p2, args{nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestJWK_AuthorizeRevoke(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + key1, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"disabled", p1, args{t1}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 458f8111..74684d2d 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -117,6 +117,7 @@ func generateJWK() (*JWK, error) { Type: "JWK", Key: &public, EncryptedKey: encrypted, + Claims: &globalProvisionerClaims, audiences: testAudiences, }, nil } @@ -143,6 +144,7 @@ func generateOIDC() (*OIDC, error) { Type: "OIDC", ClientID: clientID, ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration", + Claims: &globalProvisionerClaims, configuration: openIDConfiguration{ Issuer: issuer, JWKSetURI: "https://example.com/.well-known/jwks", @@ -174,11 +176,43 @@ func generateCollection(nJWK, nOIDC int) (*Collection, error) { } func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { - now := time.Now() - return generateToken("the-sub", []string{"test.smallstep.com"}, jwk.KeyID, iss, aud, "testdata/root_ca.crt", now, now.Add(5*time.Minute), jwk) + return generateToken("subject", iss, aud, []string{"test.smallstep.com"}, jwk) + // return generateToken("the-sub", []string{"test.smallstep.com"}, jwk.KeyID, iss, aud, "testdata/root_ca.crt", now, now.Add(5*time.Minute), jwk) } -func generateToken(sub string, sans []string, kid, iss, aud, root string, notBefore, notAfter time.Time, jwk *jose.JSONWebKey) (string, error) { +func generateToken(sub, iss, aud string, sans []string, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + now := time.Now() + claims := struct { + jose.Claims + SANS []string `json:"sans"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + SANS: sans, + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + +func generateToken2(sub string, sans []string, kid, iss, aud, root string, notBefore, notAfter time.Time, jwk *jose.JSONWebKey) (string, error) { // A random jwt id will be used to identify duplicated tokens jwtID, err := randutil.Hex(64) // 256 bits if err != nil {