Use on GCP audiences with the format https://<ca-url>#<provisioner-type>/<provisioner-name>

Fixes smallstep/step#156
This commit is contained in:
Mariano Cano 2019-06-03 17:19:44 -07:00
parent a54bf925eb
commit 0a756ce9d0
5 changed files with 104 additions and 23 deletions

View file

@ -66,8 +66,21 @@ func (c *Collection) Load(id string) (Interface, bool) {
// LoadByToken parses the token claims and loads the provisioner associated. // LoadByToken parses the token claims and loads the provisioner associated.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
var audiences []string
// Get all audiences with the given fragment
fragment := extractFragment(claims.Audience)
if fragment == "" {
audiences = c.audiences.All()
} else {
audiences = c.audiences.WithFragment(fragment).All()
}
// match with server audiences // match with server audiences
if matchesAudience(claims.Audience, c.audiences.All()) { if matchesAudience(claims.Audience, audiences) {
// Use fragment to get audiences (GCP)
if fragment != "" {
return c.Load(fragment)
}
// If matches with stored audiences it will be a JWT token (default), and // If matches with stored audiences it will be a JWT token (default), and
// the id would be <issuer>:<kid>. // the id would be <issuer>:<kid>.
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
@ -234,3 +247,13 @@ func stripPort(rawurl string) string {
u.Host = u.Hostname() u.Host = u.Hostname()
return u.String() return u.String()
} }
// extractFragment extracts the
func extractFragment(audience []string) string {
for _, s := range audience {
if u, err := url.Parse(s); err == nil && u.Fragment != "" {
return u.Fragment
}
}
return ""
}

View file

@ -77,12 +77,13 @@ type GCP struct {
claimer *Claimer claimer *Claimer
config *gcpConfig config *gcpConfig
keyStore *keyStore keyStore *keyStore
audiences Audiences
} }
// GetID returns the provisioner unique identifier. The name should uniquely // GetID returns the provisioner unique identifier. The name should uniquely
// identify any GCP provisioner. // identify any GCP provisioner.
func (p *GCP) GetID() string { func (p *GCP) GetID() string {
return "gcp:" + p.Name return "gcp/" + p.Name
} }
// GetTokenID returns the identifier of the token. The default value for GCP the // GetTokenID returns the identifier of the token. The default value for GCP the
@ -130,20 +131,25 @@ func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) {
} }
// GetIdentityURL returns the url that generates the GCP token. // GetIdentityURL returns the url that generates the GCP token.
func (p *GCP) GetIdentityURL() string { func (p *GCP) GetIdentityURL(audience string) string {
// Initialize config if required // Initialize config if required
p.assertConfig() p.assertConfig()
q := url.Values{} q := url.Values{}
q.Add("audience", p.GetID()) q.Add("audience", audience)
q.Add("format", "full") q.Add("format", "full")
q.Add("licenses", "FALSE") q.Add("licenses", "FALSE")
return fmt.Sprintf("%s?%s", p.config.IdentityURL, q.Encode()) return fmt.Sprintf("%s?%s", p.config.IdentityURL, q.Encode())
} }
// GetIdentityToken does an HTTP request to the identity url. // GetIdentityToken does an HTTP request to the identity url.
func (p *GCP) GetIdentityToken() (string, error) { func (p *GCP) GetIdentityToken(caURL string) (string, error) {
req, err := http.NewRequest("GET", p.GetIdentityURL(), http.NoBody) audience, err := generateSignAudience(caURL, p.GetID())
if err != nil {
return "", err
}
req, err := http.NewRequest("GET", p.GetIdentityURL(audience), http.NoBody)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error creating identity request") return "", errors.Wrap(err, "error creating identity request")
} }
@ -183,6 +189,8 @@ func (p *GCP) Init(config Config) error {
if err != nil { if err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetID())
return nil return nil
} }
@ -265,12 +273,16 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// more than a few minutes. // more than a few minutes.
if err = claims.ValidateWithLeeway(jose.Expected{ if err = claims.ValidateWithLeeway(jose.Expected{
Issuer: "https://accounts.google.com", Issuer: "https://accounts.google.com",
Audience: []string{p.GetID()},
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errors.Wrapf(err, "invalid token")
} }
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
}
// validate subject (service account) // validate subject (service account)
if len(p.ServiceAccounts) > 0 { if len(p.ServiceAccounts) > 0 {
var found bool var found bool

View file

@ -18,9 +18,9 @@ import (
func TestGCP_Getters(t *testing.T) { func TestGCP_Getters(t *testing.T) {
p, err := generateGCP() p, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
aud := "gcp:" + p.Name id := "gcp/" + p.Name
if got := p.GetID(); got != aud { if got := p.GetID(); got != id {
t.Errorf("GCP.GetID() = %v, want %v", got, aud) t.Errorf("GCP.GetID() = %v, want %v", got, id)
} }
if got := p.GetName(); got != p.Name { if got := p.GetName(); got != p.Name {
t.Errorf("GCP.GetName() = %v, want %v", got, p.Name) t.Errorf("GCP.GetName() = %v, want %v", got, p.Name)
@ -33,8 +33,10 @@ func TestGCP_Getters(t *testing.T) {
t.Errorf("GCP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", t.Errorf("GCP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false) kid, key, ok, "", "", false)
} }
expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(p.GetID()))
if got := p.GetIdentityURL(); got != expected { aud := "https://ca.smallstep.com/1.0/sign#" + url.QueryEscape(id)
expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(aud))
if got := p.GetIdentityURL(aud); got != expected {
t.Errorf("GCP.GetIdentityURL() = %v, want %v", got, expected) t.Errorf("GCP.GetIdentityURL() = %v, want %v", got, expected)
} }
} }
@ -50,7 +52,7 @@ func TestGCP_GetTokenID(t *testing.T) {
now := time.Now() now := time.Now()
t1, err := generateGCPToken(p1.ServiceAccounts[0], t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", "gcp:name", "https://accounts.google.com", "gcp/name",
"instance-id", "instance-name", "project-id", "zone", "instance-id", "instance-name", "project-id", "zone",
now, &p1.keyStore.keySet.Keys[0]) now, &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -60,7 +62,7 @@ func TestGCP_GetTokenID(t *testing.T) {
now, &p2.keyStore.keySet.Keys[0]) now, &p2.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
sum := sha256.Sum256([]byte("gcp:name.instance-id")) sum := sha256.Sum256([]byte("gcp/name.instance-id"))
want1 := strings.ToLower(hex.EncodeToString(sum[:])) want1 := strings.ToLower(hex.EncodeToString(sum[:]))
sum = sha256.Sum256([]byte(t2)) sum = sha256.Sum256([]byte(t2))
want2 := strings.ToLower(hex.EncodeToString(sum[:])) want2 := strings.ToLower(hex.EncodeToString(sum[:]))
@ -114,22 +116,27 @@ func TestGCP_GetIdentityToken(t *testing.T) {
})) }))
defer srv.Close() defer srv.Close()
type args struct {
caURL string
}
tests := []struct { tests := []struct {
name string name string
gcp *GCP gcp *GCP
args args
identityURL string identityURL string
want string want string
wantErr bool wantErr bool
}{ }{
{"ok", p1, srv.URL, t1, false}, {"ok", p1, args{"https://ca"}, srv.URL, t1, false},
{"fail request", p1, srv.URL + "/bad-request", "", true}, {"fail ca url", p1, args{"://ca"}, srv.URL, "", true},
{"fail url", p1, "://ca.smallstep.com", "", true}, {"fail request", p1, args{"https://ca"}, srv.URL + "/bad-request", "", true},
{"fail connect", p1, "foobarzar", "", true}, {"fail url", p1, args{"https://ca"}, "://ca.smallstep.com", "", true},
{"fail connect", p1, args{"https://ca"}, "foobarzar", "", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tt.gcp.config.IdentityURL = tt.identityURL tt.gcp.config.IdentityURL = tt.identityURL
got, err := tt.gcp.GetIdentityToken() got, err := tt.gcp.GetIdentityToken(tt.args.caURL)
t.Log(err) t.Log(err)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)

View file

@ -3,6 +3,7 @@ package provisioner
import ( import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"net/url"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -28,10 +29,44 @@ type Audiences struct {
} }
// All returns all supported audiences across all request types in one list. // All returns all supported audiences across all request types in one list.
func (a *Audiences) All() []string { func (a Audiences) All() []string {
return append(a.Sign, a.Revoke...) return append(a.Sign, a.Revoke...)
} }
// WithFragment returns a copy of audiences where the url audiences contains the
// given fragment.
func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{
Sign: make([]string, len(a.Sign)),
Revoke: make([]string, len(a.Revoke)),
}
for i, s := range a.Sign {
if u, err := url.Parse(s); err == nil {
ret.Sign[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Sign[i] = s
}
}
for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Revoke[i] = s
}
}
return ret
}
// generateSignAudience generates a sign audience with the format
// https://<ca-url>/1.0/sign#provisionerID
func generateSignAudience(caURL string, provisionerID string) (string, error) {
u, err := url.Parse(caURL)
if err != nil {
return "", errors.Wrapf(err, "error parsing %s", caURL)
}
return u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: provisionerID}).String(), nil
}
// Type indicates the provisioner Type. // Type indicates the provisioner Type.
type Type int type Type int

View file

@ -229,6 +229,7 @@ func generateGCP() (*GCP, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
audiences: testAudiences.WithFragment("gcp/" + name),
}, nil }, nil
} }
@ -492,7 +493,10 @@ func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone s
if err != nil { if err != nil {
return "", err return "", err
} }
aud, err = generateSignAudience("https://ca.smallstep.com", aud)
if err != nil {
return "", err
}
claims := gcpPayload{ claims := gcpPayload{
Claims: jose.Claims{ Claims: jose.Claims{
Subject: sub, Subject: sub,