From 4c5fec06bf10c1907cffdd7b307667361431bd12 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 7 May 2019 19:07:49 -0700 Subject: [PATCH] Require TenantID in azure, add some tests. --- authority/provisioner/azure.go | 35 ++-- authority/provisioner/azure_test.go | 246 ++++++++++++++++++++++++++++ authority/provisioner/utils_test.go | 123 ++++++++++++++ 3 files changed, 393 insertions(+), 11 deletions(-) create mode 100644 authority/provisioner/azure_test.go diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index dffa21f2..e6ac3359 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -15,8 +15,8 @@ import ( "github.com/smallstep/cli/jose" ) -// azureOIDCDiscoveryURL is the default discovery url for Microsoft Azure tokens. -const azureOIDCDiscoveryURL = "https://login.microsoftonline.com/common/.well-known/openid-configuration" +// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. +const azureOIDCBaseURL = "https://login.microsoftonline.com" // azureIdentityTokenURL is the URL to get the identity token for an instance. const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F" @@ -33,9 +33,9 @@ type azureConfig struct { identityTokenURL string } -func newAzureConfig() *azureConfig { +func newAzureConfig(tenantID string) *azureConfig { return &azureConfig{ - oidcDiscoveryURL: azureOIDCDiscoveryURL, + oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", identityTokenURL: azureIdentityTokenURL, } } @@ -77,6 +77,7 @@ type azurePayload struct { type Azure struct { Type string `json:"type"` Name string `json:"name"` + TenantID string `json:"tenantId"` Subscriptions []string `json:"subscriptions"` Audience string `json:"audience,omitempty"` DisableCustomSANs bool `json:"disableCustomSANs"` @@ -90,7 +91,7 @@ type Azure struct { // GetID returns the provisioner unique identifier. func (p *Azure) GetID() string { - return p.Audience + return p.TenantID } // GetTokenID returns the identifier of the token. The default value for Azure @@ -176,16 +177,20 @@ func (p *Azure) Init(config Config) (err error) { return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") + case p.TenantID == "": + return errors.New("provisioner tenantId cannot be empty") case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } + // Initialize config + if err := p.assertConfig(); err != nil { + return err + } // Update claims with global ones if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return err } - // Initialize configuration - p.config = newAzureConfig() // Decode and validate openid-configuration endpoint if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { @@ -209,12 +214,15 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) { if err != nil { return nil, errors.Wrapf(err, "error parsing token") } + if len(jwt.Headers) == 0 { + return nil, errors.New("error parsing token: header is missing") + } var found bool var claims azurePayload keys := p.keyStore.Get(jwt.Headers[0].KeyID) for _, key := range keys { - if err := jwt.Claims(key, &claims); err == nil { + if err := jwt.Claims(key.Public(), &claims); err == nil { found = true break } @@ -225,12 +233,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) { if err := claims.ValidateWithLeeway(jose.Expected{ Audience: []string{p.Audience}, - Issuer: strings.Replace(p.oidcConfig.Issuer, "{tenantid}", claims.TenantID, 1), + Issuer: p.oidcConfig.Issuer, Time: time.Now(), }, 1*time.Minute); err != nil { return nil, errors.Wrap(err, "failed to validate payload") } + // Validate TenantID + if claims.TenantID != p.TenantID { + return nil, errors.New("validation failed: invalid tenant id claim (tid)") + } + re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) if len(re) == 0 { return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID) @@ -247,7 +260,7 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) { } } if !found { - return nil, errors.Errorf("subscription %s is not valid", subscription) + return nil, errors.New("validation failed: invalid subscription id") } } @@ -287,6 +300,6 @@ func (p *Azure) assertConfig() error { if p.config != nil { return nil } - p.config = newAzureConfig() + p.config = newAzureConfig(p.TenantID) return nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go new file mode 100644 index 00000000..c986b5ce --- /dev/null +++ b/authority/provisioner/azure_test.go @@ -0,0 +1,246 @@ +package provisioner + +import ( + "crypto/x509" + "reflect" + "testing" + + "github.com/smallstep/assert" +) + +func TestAzure_Getters(t *testing.T) { + p, err := generateAzure() + assert.FatalError(t, err) + if got := p.GetID(); got != p.TenantID { + t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID) + } + if got := p.GetName(); got != p.Name { + t.Errorf("Azure.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeAzure { + t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure) + } + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, "", "", false) + } +} + +func TestAzure_GetTokenID(t *testing.T) { + type fields struct { + Type string + Name string + DisableCustomSANs bool + DisableTrustOnFirstUse bool + Claims *Claims + claimer *Claimer + config *azureConfig + } + type args struct { + token string + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Azure{ + Type: tt.fields.Type, + Name: tt.fields.Name, + DisableCustomSANs: tt.fields.DisableCustomSANs, + DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, + Claims: tt.fields.Claims, + claimer: tt.fields.claimer, + config: tt.fields.config, + } + got, err := p.GetTokenID(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAzure_Init(t *testing.T) { + az, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + + config := Config{ + Claims: globalProvisionerClaims, + } + badClaims := &Claims{ + DefaultTLSDur: &Duration{0}, + } + + type fields struct { + Type string + Name string + TenantID string + DisableCustomSANs bool + DisableTrustOnFirstUse bool + Claims *Claims + } + type args struct { + config Config + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{az.Type, az.Name, az.TenantID, false, false, nil}, args{config}, false}, + {"ok", fields{az.Type, az.Name, az.TenantID, true, false, nil}, args{config}, false}, + {"ok", fields{az.Type, az.Name, az.TenantID, false, true, nil}, args{config}, false}, + {"ok", fields{az.Type, az.Name, az.TenantID, true, true, nil}, args{config}, false}, + {"fail type", fields{"", az.Name, az.TenantID, false, false, nil}, args{config}, true}, + {"fail name", fields{az.Type, "", az.TenantID, false, false, nil}, args{config}, true}, + {"fail tenant id", fields{az.Type, az.Name, "", false, false, nil}, args{config}, true}, + {"fail claims", fields{az.Type, az.Name, az.TenantID, false, false, badClaims}, args{config}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Azure{ + Type: tt.fields.Type, + Name: tt.fields.Name, + TenantID: tt.fields.TenantID, + DisableCustomSANs: tt.fields.DisableCustomSANs, + DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, + Claims: tt.fields.Claims, + config: az.config, + } + if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { + t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAzure_AuthorizeSign(t *testing.T) { + type fields struct { + Type string + Name string + DisableCustomSANs bool + DisableTrustOnFirstUse bool + Claims *Claims + claimer *Claimer + config *azureConfig + } + type args struct { + token string + } + tests := []struct { + name string + fields fields + args args + want []SignOption + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Azure{ + Type: tt.fields.Type, + Name: tt.fields.Name, + DisableCustomSANs: tt.fields.DisableCustomSANs, + DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, + Claims: tt.fields.Claims, + claimer: tt.fields.claimer, + config: tt.fields.config, + } + got, err := p.AuthorizeSign(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Azure.AuthorizeSign() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAzure_AuthorizeRenewal(t *testing.T) { + p1, err := generateAzure() + assert.FatalError(t, err) + p2, err := generateAzure() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{DisableRenewal: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + azure *Azure + args args + wantErr bool + }{ + {"ok", p1, args{nil}, false}, + {"fail", p2, args{nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.azure.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Azure.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAzure_AuthorizeRevoke(t *testing.T) { + type fields struct { + Type string + Name string + DisableCustomSANs bool + DisableTrustOnFirstUse bool + Claims *Claims + claimer *Claimer + config *azureConfig + } + type args struct { + token string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Azure{ + Type: tt.fields.Type, + Name: tt.fields.Name, + DisableCustomSANs: tt.fields.DisableCustomSANs, + DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, + Claims: tt.fields.Claims, + claimer: tt.fields.claimer, + config: tt.fields.config, + } + if err := p.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 47d42622..94fc7015 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "encoding/json" "encoding/pem" + "fmt" "net/http" "net/http/httptest" "time" @@ -328,6 +329,99 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) { return aws, srv, nil } +func generateAzure() (*Azure, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + tenantID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + claimer, err := NewClaimer(nil, globalProvisionerClaims) + if err != nil { + return nil, err + } + jwk, err := generateJSONWebKey() + if err != nil { + return nil, err + } + return &Azure{ + Type: "Azure", + Name: name, + TenantID: tenantID, + Claims: &globalProvisionerClaims, + claimer: claimer, + config: newAzureConfig(tenantID), + oidcConfig: openIDConfiguration{ + Issuer: "https://sts.windows.net/" + tenantID + "/", + JWKSetURI: "https://login.microsoftonline.com/common/discovery/keys", + }, + keyStore: &keyStore{ + keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, + expiry: time.Now().Add(24 * time.Hour), + }, + }, nil +} + +func generateAzureWithServer() (*Azure, *httptest.Server, error) { + az, err := generateAzure() + if err != nil { + return nil, nil, err + } + writeJSON := func(w http.ResponseWriter, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) + } + getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet { + var ret jose.JSONWebKeySet + for _, k := range ks.Keys { + ret.Keys = append(ret.Keys, k.Public()) + } + return ret + } + issuer := "https://sts.windows.net/" + az.TenantID + "/" + srv := httptest.NewUnstartedServer(nil) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/error": + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + case "/" + az.TenantID + "/.well-known/openid-configuration": + writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"}) + case "/random": + keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, getPublic(keySet)) + case "/private": + writeJSON(w, az.keyStore.keySet) + case "/jwks_uri": + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, getPublic(az.keyStore.keySet)) + case "/metadata/identity/oauth2/token": + tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(w, azureIdentityToken{ + AccessToken: tok, + }) + } + default: + http.NotFound(w, r) + } + }) + srv.Start() + az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration" + az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token" + return az, srv, nil +} + func generateCollection(nJWK, nOIDC int) (*Collection, error) { col := NewCollection(testAudiences) for i := 0; i < nJWK; i++ { @@ -468,6 +562,35 @@ func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region st return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err + } + + claims := azurePayload{ + Claims: jose.Claims{ + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + AppID: "the-appid", + AppIDAcr: "the-appidacr", + IdentityProvider: "the-idp", + ObjectID: "the-oid", + TenantID: tenantID, + Version: "the-version", + XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { tok, err := jose.ParseSigned(token) if err != nil {