Require TenantID in azure, add some tests.
This commit is contained in:
parent
12937c6b75
commit
4c5fec06bf
3 changed files with 393 additions and 11 deletions
|
@ -15,8 +15,8 @@ import (
|
|||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// azureOIDCDiscoveryURL is the default discovery url for Microsoft Azure tokens.
|
||||
const azureOIDCDiscoveryURL = "https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
||||
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
|
||||
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
||||
|
||||
// azureIdentityTokenURL is the URL to get the identity token for an instance.
|
||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
||||
|
@ -33,9 +33,9 @@ type azureConfig struct {
|
|||
identityTokenURL string
|
||||
}
|
||||
|
||||
func newAzureConfig() *azureConfig {
|
||||
func newAzureConfig(tenantID string) *azureConfig {
|
||||
return &azureConfig{
|
||||
oidcDiscoveryURL: azureOIDCDiscoveryURL,
|
||||
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||
identityTokenURL: azureIdentityTokenURL,
|
||||
}
|
||||
}
|
||||
|
@ -77,6 +77,7 @@ type azurePayload struct {
|
|||
type Azure struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
TenantID string `json:"tenantId"`
|
||||
Subscriptions []string `json:"subscriptions"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
DisableCustomSANs bool `json:"disableCustomSANs"`
|
||||
|
@ -90,7 +91,7 @@ type Azure struct {
|
|||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p *Azure) GetID() string {
|
||||
return p.Audience
|
||||
return p.TenantID
|
||||
}
|
||||
|
||||
// GetTokenID returns the identifier of the token. The default value for Azure
|
||||
|
@ -176,16 +177,20 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
return errors.New("provisioner type cannot be empty")
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
case p.TenantID == "":
|
||||
return errors.New("provisioner tenantId cannot be empty")
|
||||
case p.Audience == "": // use default audience
|
||||
p.Audience = azureDefaultAudience
|
||||
}
|
||||
// Initialize config
|
||||
if err := p.assertConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
// Initialize configuration
|
||||
p.config = newAzureConfig()
|
||||
|
||||
// Decode and validate openid-configuration endpoint
|
||||
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
||||
|
@ -209,12 +214,15 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, errors.New("error parsing token: header is missing")
|
||||
}
|
||||
|
||||
var found bool
|
||||
var claims azurePayload
|
||||
keys := p.keyStore.Get(jwt.Headers[0].KeyID)
|
||||
for _, key := range keys {
|
||||
if err := jwt.Claims(key, &claims); err == nil {
|
||||
if err := jwt.Claims(key.Public(), &claims); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
@ -225,12 +233,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Audience: []string{p.Audience},
|
||||
Issuer: strings.Replace(p.oidcConfig.Issuer, "{tenantid}", claims.TenantID, 1),
|
||||
Issuer: p.oidcConfig.Issuer,
|
||||
Time: time.Now(),
|
||||
}, 1*time.Minute); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to validate payload")
|
||||
}
|
||||
|
||||
// Validate TenantID
|
||||
if claims.TenantID != p.TenantID {
|
||||
return nil, errors.New("validation failed: invalid tenant id claim (tid)")
|
||||
}
|
||||
|
||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||
if len(re) == 0 {
|
||||
return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
|
||||
|
@ -247,7 +260,7 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.Errorf("subscription %s is not valid", subscription)
|
||||
return nil, errors.New("validation failed: invalid subscription id")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -287,6 +300,6 @@ func (p *Azure) assertConfig() error {
|
|||
if p.config != nil {
|
||||
return nil
|
||||
}
|
||||
p.config = newAzureConfig()
|
||||
p.config = newAzureConfig(p.TenantID)
|
||||
return nil
|
||||
}
|
||||
|
|
246
authority/provisioner/azure_test.go
Normal file
246
authority/provisioner/azure_test.go
Normal file
|
@ -0,0 +1,246 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestAzure_Getters(t *testing.T) {
|
||||
p, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
if got := p.GetID(); got != p.TenantID {
|
||||
t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("Azure.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeAzure {
|
||||
t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_GetTokenID(t *testing.T) {
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
DisableCustomSANs bool
|
||||
DisableTrustOnFirstUse bool
|
||||
Claims *Claims
|
||||
claimer *Claimer
|
||||
config *azureConfig
|
||||
}
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Azure{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||
Claims: tt.fields.Claims,
|
||||
claimer: tt.fields.claimer,
|
||||
config: tt.fields.config,
|
||||
}
|
||||
got, err := p.GetTokenID(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_Init(t *testing.T) {
|
||||
az, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
}
|
||||
badClaims := &Claims{
|
||||
DefaultTLSDur: &Duration{0},
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
TenantID string
|
||||
DisableCustomSANs bool
|
||||
DisableTrustOnFirstUse bool
|
||||
Claims *Claims
|
||||
}
|
||||
type args struct {
|
||||
config Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{az.Type, az.Name, az.TenantID, false, false, nil}, args{config}, false},
|
||||
{"ok", fields{az.Type, az.Name, az.TenantID, true, false, nil}, args{config}, false},
|
||||
{"ok", fields{az.Type, az.Name, az.TenantID, false, true, nil}, args{config}, false},
|
||||
{"ok", fields{az.Type, az.Name, az.TenantID, true, true, nil}, args{config}, false},
|
||||
{"fail type", fields{"", az.Name, az.TenantID, false, false, nil}, args{config}, true},
|
||||
{"fail name", fields{az.Type, "", az.TenantID, false, false, nil}, args{config}, true},
|
||||
{"fail tenant id", fields{az.Type, az.Name, "", false, false, nil}, args{config}, true},
|
||||
{"fail claims", fields{az.Type, az.Name, az.TenantID, false, false, badClaims}, args{config}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Azure{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
TenantID: tt.fields.TenantID,
|
||||
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||
Claims: tt.fields.Claims,
|
||||
config: az.config,
|
||||
}
|
||||
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeSign(t *testing.T) {
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
DisableCustomSANs bool
|
||||
DisableTrustOnFirstUse bool
|
||||
Claims *Claims
|
||||
claimer *Claimer
|
||||
config *azureConfig
|
||||
}
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []SignOption
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Azure{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||
Claims: tt.fields.Claims,
|
||||
claimer: tt.fields.claimer,
|
||||
config: tt.fields.config,
|
||||
}
|
||||
got, err := p.AuthorizeSign(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Azure.AuthorizeSign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.azure.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRevoke(t *testing.T) {
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
DisableCustomSANs bool
|
||||
DisableTrustOnFirstUse bool
|
||||
Claims *Claims
|
||||
claimer *Claimer
|
||||
config *azureConfig
|
||||
}
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Azure{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||
Claims: tt.fields.Claims,
|
||||
claimer: tt.fields.claimer,
|
||||
config: tt.fields.config,
|
||||
}
|
||||
if err := p.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
@ -328,6 +329,99 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
|||
return aws, srv, nil
|
||||
}
|
||||
|
||||
func generateAzure() (*Azure, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tenantID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Azure{
|
||||
Type: "Azure",
|
||||
Name: name,
|
||||
TenantID: tenantID,
|
||||
Claims: &globalProvisionerClaims,
|
||||
claimer: claimer,
|
||||
config: newAzureConfig(tenantID),
|
||||
oidcConfig: openIDConfiguration{
|
||||
Issuer: "https://sts.windows.net/" + tenantID + "/",
|
||||
JWKSetURI: "https://login.microsoftonline.com/common/discovery/keys",
|
||||
},
|
||||
keyStore: &keyStore{
|
||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||
az, err := generateAzure()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}
|
||||
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
|
||||
var ret jose.JSONWebKeySet
|
||||
for _, k := range ks.Keys {
|
||||
ret.Keys = append(ret.Keys, k.Public())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
issuer := "https://sts.windows.net/" + az.TenantID + "/"
|
||||
srv := httptest.NewUnstartedServer(nil)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/error":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/" + az.TenantID + "/.well-known/openid-configuration":
|
||||
writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"})
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(keySet))
|
||||
case "/private":
|
||||
writeJSON(w, az.keyStore.keySet)
|
||||
case "/jwks_uri":
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(az.keyStore.keySet))
|
||||
case "/metadata/identity/oauth2/token":
|
||||
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
} else {
|
||||
writeJSON(w, azureIdentityToken{
|
||||
AccessToken: tok,
|
||||
})
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
srv.Start()
|
||||
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
||||
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
||||
return az, srv, nil
|
||||
}
|
||||
|
||||
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||
col := NewCollection(testAudiences)
|
||||
for i := 0; i < nJWK; i++ {
|
||||
|
@ -468,6 +562,35 @@ func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region st
|
|||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := azurePayload{
|
||||
Claims: jose.Claims{
|
||||
Subject: sub,
|
||||
Issuer: iss,
|
||||
IssuedAt: jose.NewNumericDate(iat),
|
||||
NotBefore: jose.NewNumericDate(iat),
|
||||
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||
Audience: []string{aud},
|
||||
},
|
||||
AppID: "the-appid",
|
||||
AppIDAcr: "the-appidacr",
|
||||
IdentityProvider: "the-idp",
|
||||
ObjectID: "the-oid",
|
||||
TenantID: tenantID,
|
||||
Version: "the-version",
|
||||
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue