diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index c3c6518c..d0525ad3 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -66,8 +66,21 @@ func (c *Collection) Load(id string) (Interface, bool) { // LoadByToken parses the token claims and loads the provisioner associated. func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { + var audiences []string + // Get all audiences with the given fragment + fragment := extractFragment(claims.Audience) + if fragment == "" { + audiences = c.audiences.All() + } else { + audiences = c.audiences.WithFragment(fragment).All() + } + // match with server audiences - if matchesAudience(claims.Audience, c.audiences.All()) { + if matchesAudience(claims.Audience, audiences) { + // Use fragment to get audiences (GCP) + if fragment != "" { + return c.Load(fragment) + } // If matches with stored audiences it will be a JWT token (default), and // the id would be :. return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) @@ -234,3 +247,13 @@ func stripPort(rawurl string) string { u.Host = u.Hostname() return u.String() } + +// extractFragment extracts the +func extractFragment(audience []string) string { + for _, s := range audience { + if u, err := url.Parse(s); err == nil && u.Fragment != "" { + return u.Fragment + } + } + return "" +} diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 8ee3d86c..953ed4c3 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -77,12 +77,13 @@ type GCP struct { claimer *Claimer config *gcpConfig keyStore *keyStore + audiences Audiences } // GetID returns the provisioner unique identifier. The name should uniquely // identify any GCP provisioner. func (p *GCP) GetID() string { - return "gcp:" + p.Name + return "gcp/" + p.Name } // GetTokenID returns the identifier of the token. The default value for GCP the @@ -130,20 +131,25 @@ func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) { } // GetIdentityURL returns the url that generates the GCP token. -func (p *GCP) GetIdentityURL() string { +func (p *GCP) GetIdentityURL(audience string) string { // Initialize config if required p.assertConfig() q := url.Values{} - q.Add("audience", p.GetID()) + q.Add("audience", audience) q.Add("format", "full") q.Add("licenses", "FALSE") return fmt.Sprintf("%s?%s", p.config.IdentityURL, q.Encode()) } // GetIdentityToken does an HTTP request to the identity url. -func (p *GCP) GetIdentityToken() (string, error) { - req, err := http.NewRequest("GET", p.GetIdentityURL(), http.NoBody) +func (p *GCP) GetIdentityToken(caURL string) (string, error) { + audience, err := generateSignAudience(caURL, p.GetID()) + if err != nil { + return "", err + } + + req, err := http.NewRequest("GET", p.GetIdentityURL(audience), http.NoBody) if err != nil { return "", errors.Wrap(err, "error creating identity request") } @@ -183,6 +189,8 @@ func (p *GCP) Init(config Config) error { if err != nil { return err } + + p.audiences = config.Audiences.WithFragment(p.GetID()) return nil } @@ -264,13 +272,17 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. if err = claims.ValidateWithLeeway(jose.Expected{ - Issuer: "https://accounts.google.com", - Audience: []string{p.GetID()}, - Time: time.Now().UTC(), + Issuer: "https://accounts.google.com", + Time: time.Now().UTC(), }, time.Minute); err != nil { return nil, errors.Wrapf(err, "invalid token") } + // validate audiences with the defaults + if !matchesAudience(claims.Audience, p.audiences.Sign) { + return nil, errors.New("invalid token: invalid audience claim (aud)") + } + // validate subject (service account) if len(p.ServiceAccounts) > 0 { var found bool diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 75eac9bf..c4a7ac24 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -18,9 +18,9 @@ import ( func TestGCP_Getters(t *testing.T) { p, err := generateGCP() assert.FatalError(t, err) - aud := "gcp:" + p.Name - if got := p.GetID(); got != aud { - t.Errorf("GCP.GetID() = %v, want %v", got, aud) + id := "gcp/" + p.Name + if got := p.GetID(); got != id { + t.Errorf("GCP.GetID() = %v, want %v", got, id) } if got := p.GetName(); got != p.Name { t.Errorf("GCP.GetName() = %v, want %v", got, p.Name) @@ -33,8 +33,10 @@ func TestGCP_Getters(t *testing.T) { t.Errorf("GCP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } - expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(p.GetID())) - if got := p.GetIdentityURL(); got != expected { + + aud := "https://ca.smallstep.com/1.0/sign#" + url.QueryEscape(id) + expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(aud)) + if got := p.GetIdentityURL(aud); got != expected { t.Errorf("GCP.GetIdentityURL() = %v, want %v", got, expected) } } @@ -50,7 +52,7 @@ func TestGCP_GetTokenID(t *testing.T) { now := time.Now() t1, err := generateGCPToken(p1.ServiceAccounts[0], - "https://accounts.google.com", "gcp:name", + "https://accounts.google.com", "gcp/name", "instance-id", "instance-name", "project-id", "zone", now, &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) @@ -60,7 +62,7 @@ func TestGCP_GetTokenID(t *testing.T) { now, &p2.keyStore.keySet.Keys[0]) assert.FatalError(t, err) - sum := sha256.Sum256([]byte("gcp:name.instance-id")) + sum := sha256.Sum256([]byte("gcp/name.instance-id")) want1 := strings.ToLower(hex.EncodeToString(sum[:])) sum = sha256.Sum256([]byte(t2)) want2 := strings.ToLower(hex.EncodeToString(sum[:])) @@ -114,22 +116,27 @@ func TestGCP_GetIdentityToken(t *testing.T) { })) defer srv.Close() + type args struct { + caURL string + } tests := []struct { name string gcp *GCP + args args identityURL string want string wantErr bool }{ - {"ok", p1, srv.URL, t1, false}, - {"fail request", p1, srv.URL + "/bad-request", "", true}, - {"fail url", p1, "://ca.smallstep.com", "", true}, - {"fail connect", p1, "foobarzar", "", true}, + {"ok", p1, args{"https://ca"}, srv.URL, t1, false}, + {"fail ca url", p1, args{"://ca"}, srv.URL, "", true}, + {"fail request", p1, args{"https://ca"}, srv.URL + "/bad-request", "", true}, + {"fail url", p1, args{"https://ca"}, "://ca.smallstep.com", "", true}, + {"fail connect", p1, args{"https://ca"}, "foobarzar", "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.gcp.config.IdentityURL = tt.identityURL - got, err := tt.gcp.GetIdentityToken() + got, err := tt.gcp.GetIdentityToken(tt.args.caURL) t.Log(err) if (err != nil) != tt.wantErr { t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 8dc95586..29ebc902 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -3,6 +3,7 @@ package provisioner import ( "crypto/x509" "encoding/json" + "net/url" "strings" "github.com/pkg/errors" @@ -28,10 +29,44 @@ type Audiences struct { } // All returns all supported audiences across all request types in one list. -func (a *Audiences) All() []string { +func (a Audiences) All() []string { return append(a.Sign, a.Revoke...) } +// WithFragment returns a copy of audiences where the url audiences contains the +// given fragment. +func (a Audiences) WithFragment(fragment string) Audiences { + ret := Audiences{ + Sign: make([]string, len(a.Sign)), + Revoke: make([]string, len(a.Revoke)), + } + for i, s := range a.Sign { + if u, err := url.Parse(s); err == nil { + ret.Sign[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() + } else { + ret.Sign[i] = s + } + } + for i, s := range a.Revoke { + if u, err := url.Parse(s); err == nil { + ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() + } else { + ret.Revoke[i] = s + } + } + return ret +} + +// generateSignAudience generates a sign audience with the format +// https:///1.0/sign#provisionerID +func generateSignAudience(caURL string, provisionerID string) (string, error) { + u, err := url.Parse(caURL) + if err != nil { + return "", errors.Wrapf(err, "error parsing %s", caURL) + } + return u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: provisionerID}).String(), nil +} + // Type indicates the provisioner Type. type Type int diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 23175677..d89cbc5d 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -229,6 +229,7 @@ func generateGCP() (*GCP, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, + audiences: testAudiences.WithFragment("gcp/" + name), }, nil } @@ -492,7 +493,10 @@ func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone s if err != nil { return "", err } - + aud, err = generateSignAudience("https://ca.smallstep.com", aud) + if err != nil { + return "", err + } claims := gcpPayload{ Claims: jose.Claims{ Subject: sub,