Require TenantID in azure, add some tests.

This commit is contained in:
Mariano Cano 2019-05-07 19:07:49 -07:00
parent 12937c6b75
commit 4c5fec06bf
3 changed files with 393 additions and 11 deletions

View file

@ -15,8 +15,8 @@ import (
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
// azureOIDCDiscoveryURL is the default discovery url for Microsoft Azure tokens. // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
const azureOIDCDiscoveryURL = "https://login.microsoftonline.com/common/.well-known/openid-configuration" const azureOIDCBaseURL = "https://login.microsoftonline.com"
// azureIdentityTokenURL is the URL to get the identity token for an instance. // azureIdentityTokenURL is the URL to get the identity token for an instance.
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F" const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
@ -33,9 +33,9 @@ type azureConfig struct {
identityTokenURL string identityTokenURL string
} }
func newAzureConfig() *azureConfig { func newAzureConfig(tenantID string) *azureConfig {
return &azureConfig{ return &azureConfig{
oidcDiscoveryURL: azureOIDCDiscoveryURL, oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
identityTokenURL: azureIdentityTokenURL, identityTokenURL: azureIdentityTokenURL,
} }
} }
@ -77,6 +77,7 @@ type azurePayload struct {
type Azure struct { type Azure struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
TenantID string `json:"tenantId"`
Subscriptions []string `json:"subscriptions"` Subscriptions []string `json:"subscriptions"`
Audience string `json:"audience,omitempty"` Audience string `json:"audience,omitempty"`
DisableCustomSANs bool `json:"disableCustomSANs"` DisableCustomSANs bool `json:"disableCustomSANs"`
@ -90,7 +91,7 @@ type Azure struct {
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
func (p *Azure) GetID() string { func (p *Azure) GetID() string {
return p.Audience return p.TenantID
} }
// GetTokenID returns the identifier of the token. The default value for Azure // GetTokenID returns the identifier of the token. The default value for Azure
@ -176,16 +177,20 @@ func (p *Azure) Init(config Config) (err error) {
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
case p.Name == "": case p.Name == "":
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
case p.TenantID == "":
return errors.New("provisioner tenantId cannot be empty")
case p.Audience == "": // use default audience case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience p.Audience = azureDefaultAudience
} }
// Initialize config
if err := p.assertConfig(); err != nil {
return err
}
// Update claims with global ones // Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err return err
} }
// Initialize configuration
p.config = newAzureConfig()
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
@ -209,12 +214,15 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errors.Wrapf(err, "error parsing token")
} }
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
}
var found bool var found bool
var claims azurePayload var claims azurePayload
keys := p.keyStore.Get(jwt.Headers[0].KeyID) keys := p.keyStore.Get(jwt.Headers[0].KeyID)
for _, key := range keys { for _, key := range keys {
if err := jwt.Claims(key, &claims); err == nil { if err := jwt.Claims(key.Public(), &claims); err == nil {
found = true found = true
break break
} }
@ -225,12 +233,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Audience: []string{p.Audience}, Audience: []string{p.Audience},
Issuer: strings.Replace(p.oidcConfig.Issuer, "{tenantid}", claims.TenantID, 1), Issuer: p.oidcConfig.Issuer,
Time: time.Now(), Time: time.Now(),
}, 1*time.Minute); err != nil { }, 1*time.Minute); err != nil {
return nil, errors.Wrap(err, "failed to validate payload") return nil, errors.Wrap(err, "failed to validate payload")
} }
// Validate TenantID
if claims.TenantID != p.TenantID {
return nil, errors.New("validation failed: invalid tenant id claim (tid)")
}
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) == 0 { if len(re) == 0 {
return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID) return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
@ -247,7 +260,7 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
} }
} }
if !found { if !found {
return nil, errors.Errorf("subscription %s is not valid", subscription) return nil, errors.New("validation failed: invalid subscription id")
} }
} }
@ -287,6 +300,6 @@ func (p *Azure) assertConfig() error {
if p.config != nil { if p.config != nil {
return nil return nil
} }
p.config = newAzureConfig() p.config = newAzureConfig(p.TenantID)
return nil return nil
} }

View file

@ -0,0 +1,246 @@
package provisioner
import (
"crypto/x509"
"reflect"
"testing"
"github.com/smallstep/assert"
)
func TestAzure_Getters(t *testing.T) {
p, err := generateAzure()
assert.FatalError(t, err)
if got := p.GetID(); got != p.TenantID {
t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID)
}
if got := p.GetName(); got != p.Name {
t.Errorf("Azure.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeAzure {
t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestAzure_GetTokenID(t *testing.T) {
type fields struct {
Type string
Name string
DisableCustomSANs bool
DisableTrustOnFirstUse bool
Claims *Claims
claimer *Claimer
config *azureConfig
}
type args struct {
token string
}
tests := []struct {
name string
fields fields
args args
want string
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Azure{
Type: tt.fields.Type,
Name: tt.fields.Name,
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
Claims: tt.fields.Claims,
claimer: tt.fields.claimer,
config: tt.fields.config,
}
got, err := p.GetTokenID(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want)
}
})
}
}
func TestAzure_Init(t *testing.T) {
az, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
config := Config{
Claims: globalProvisionerClaims,
}
badClaims := &Claims{
DefaultTLSDur: &Duration{0},
}
type fields struct {
Type string
Name string
TenantID string
DisableCustomSANs bool
DisableTrustOnFirstUse bool
Claims *Claims
}
type args struct {
config Config
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{az.Type, az.Name, az.TenantID, false, false, nil}, args{config}, false},
{"ok", fields{az.Type, az.Name, az.TenantID, true, false, nil}, args{config}, false},
{"ok", fields{az.Type, az.Name, az.TenantID, false, true, nil}, args{config}, false},
{"ok", fields{az.Type, az.Name, az.TenantID, true, true, nil}, args{config}, false},
{"fail type", fields{"", az.Name, az.TenantID, false, false, nil}, args{config}, true},
{"fail name", fields{az.Type, "", az.TenantID, false, false, nil}, args{config}, true},
{"fail tenant id", fields{az.Type, az.Name, "", false, false, nil}, args{config}, true},
{"fail claims", fields{az.Type, az.Name, az.TenantID, false, false, badClaims}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Azure{
Type: tt.fields.Type,
Name: tt.fields.Name,
TenantID: tt.fields.TenantID,
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
Claims: tt.fields.Claims,
config: az.config,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeSign(t *testing.T) {
type fields struct {
Type string
Name string
DisableCustomSANs bool
DisableTrustOnFirstUse bool
Claims *Claims
claimer *Claimer
config *azureConfig
}
type args struct {
token string
}
tests := []struct {
name string
fields fields
args args
want []SignOption
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Azure{
Type: tt.fields.Type,
Name: tt.fields.Name,
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
Claims: tt.fields.Claims,
claimer: tt.fields.claimer,
config: tt.fields.config,
}
got, err := p.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Azure.AuthorizeSign() = %v, want %v", got, tt.want)
}
})
}
}
func TestAzure_AuthorizeRenewal(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeRevoke(t *testing.T) {
type fields struct {
Type string
Name string
DisableCustomSANs bool
DisableTrustOnFirstUse bool
Claims *Claims
claimer *Claimer
config *azureConfig
}
type args struct {
token string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Azure{
Type: tt.fields.Type,
Name: tt.fields.Name,
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
Claims: tt.fields.Claims,
claimer: tt.fields.claimer,
config: tt.fields.config,
}
if err := p.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -9,6 +9,7 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"time" "time"
@ -328,6 +329,99 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
return aws, srv, nil return aws, srv, nil
} }
func generateAzure() (*Azure, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
tenantID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
return &Azure{
Type: "Azure",
Name: name,
TenantID: tenantID,
Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/",
JWKSetURI: "https://login.microsoftonline.com/common/discovery/keys",
},
keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
}, nil
}
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
az, err := generateAzure()
if err != nil {
return nil, nil, err
}
writeJSON := func(w http.ResponseWriter, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
var ret jose.JSONWebKeySet
for _, k := range ks.Keys {
ret.Keys = append(ret.Keys, k.Public())
}
return ret
}
issuer := "https://sts.windows.net/" + az.TenantID + "/"
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/error":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/" + az.TenantID + "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"})
case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(keySet))
case "/private":
writeJSON(w, az.keyStore.keySet)
case "/jwks_uri":
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
writeJSON(w, azureIdentityToken{
AccessToken: tok,
})
}
default:
http.NotFound(w, r)
}
})
srv.Start()
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
return az, srv, nil
}
func generateCollection(nJWK, nOIDC int) (*Collection, error) { func generateCollection(nJWK, nOIDC int) (*Collection, error) {
col := NewCollection(testAudiences) col := NewCollection(testAudiences)
for i := 0; i < nJWK; i++ { for i := 0; i < nJWK; i++ {
@ -468,6 +562,35 @@ func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region st
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
claims := azurePayload{
Claims: jose.Claims{
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
AppID: "the-appid",
AppIDAcr: "the-appidacr",
IdentityProvider: "the-idp",
ObjectID: "the-oid",
TenantID: tenantID,
Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
tok, err := jose.ParseSigned(token) tok, err := jose.ParseSigned(token)
if err != nil { if err != nil {