forked from TrueCloudLab/certificates
Require TenantID in azure, add some tests.
This commit is contained in:
parent
12937c6b75
commit
4c5fec06bf
3 changed files with 393 additions and 11 deletions
|
@ -15,8 +15,8 @@ import (
|
||||||
"github.com/smallstep/cli/jose"
|
"github.com/smallstep/cli/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
// azureOIDCDiscoveryURL is the default discovery url for Microsoft Azure tokens.
|
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
|
||||||
const azureOIDCDiscoveryURL = "https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
||||||
|
|
||||||
// azureIdentityTokenURL is the URL to get the identity token for an instance.
|
// azureIdentityTokenURL is the URL to get the identity token for an instance.
|
||||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
||||||
|
@ -33,9 +33,9 @@ type azureConfig struct {
|
||||||
identityTokenURL string
|
identityTokenURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAzureConfig() *azureConfig {
|
func newAzureConfig(tenantID string) *azureConfig {
|
||||||
return &azureConfig{
|
return &azureConfig{
|
||||||
oidcDiscoveryURL: azureOIDCDiscoveryURL,
|
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||||
identityTokenURL: azureIdentityTokenURL,
|
identityTokenURL: azureIdentityTokenURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,6 +77,7 @@ type azurePayload struct {
|
||||||
type Azure struct {
|
type Azure struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
TenantID string `json:"tenantId"`
|
||||||
Subscriptions []string `json:"subscriptions"`
|
Subscriptions []string `json:"subscriptions"`
|
||||||
Audience string `json:"audience,omitempty"`
|
Audience string `json:"audience,omitempty"`
|
||||||
DisableCustomSANs bool `json:"disableCustomSANs"`
|
DisableCustomSANs bool `json:"disableCustomSANs"`
|
||||||
|
@ -90,7 +91,7 @@ type Azure struct {
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
func (p *Azure) GetID() string {
|
func (p *Azure) GetID() string {
|
||||||
return p.Audience
|
return p.TenantID
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTokenID returns the identifier of the token. The default value for Azure
|
// GetTokenID returns the identifier of the token. The default value for Azure
|
||||||
|
@ -176,16 +177,20 @@ func (p *Azure) Init(config Config) (err error) {
|
||||||
return errors.New("provisioner type cannot be empty")
|
return errors.New("provisioner type cannot be empty")
|
||||||
case p.Name == "":
|
case p.Name == "":
|
||||||
return errors.New("provisioner name cannot be empty")
|
return errors.New("provisioner name cannot be empty")
|
||||||
|
case p.TenantID == "":
|
||||||
|
return errors.New("provisioner tenantId cannot be empty")
|
||||||
case p.Audience == "": // use default audience
|
case p.Audience == "": // use default audience
|
||||||
p.Audience = azureDefaultAudience
|
p.Audience = azureDefaultAudience
|
||||||
}
|
}
|
||||||
|
// Initialize config
|
||||||
|
if err := p.assertConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
// Update claims with global ones
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Initialize configuration
|
|
||||||
p.config = newAzureConfig()
|
|
||||||
|
|
||||||
// Decode and validate openid-configuration endpoint
|
// Decode and validate openid-configuration endpoint
|
||||||
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
||||||
|
@ -209,12 +214,15 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "error parsing token")
|
return nil, errors.Wrapf(err, "error parsing token")
|
||||||
}
|
}
|
||||||
|
if len(jwt.Headers) == 0 {
|
||||||
|
return nil, errors.New("error parsing token: header is missing")
|
||||||
|
}
|
||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
var claims azurePayload
|
var claims azurePayload
|
||||||
keys := p.keyStore.Get(jwt.Headers[0].KeyID)
|
keys := p.keyStore.Get(jwt.Headers[0].KeyID)
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if err := jwt.Claims(key, &claims); err == nil {
|
if err := jwt.Claims(key.Public(), &claims); err == nil {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -225,12 +233,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
|
|
||||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
Audience: []string{p.Audience},
|
Audience: []string{p.Audience},
|
||||||
Issuer: strings.Replace(p.oidcConfig.Issuer, "{tenantid}", claims.TenantID, 1),
|
Issuer: p.oidcConfig.Issuer,
|
||||||
Time: time.Now(),
|
Time: time.Now(),
|
||||||
}, 1*time.Minute); err != nil {
|
}, 1*time.Minute); err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to validate payload")
|
return nil, errors.Wrap(err, "failed to validate payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate TenantID
|
||||||
|
if claims.TenantID != p.TenantID {
|
||||||
|
return nil, errors.New("validation failed: invalid tenant id claim (tid)")
|
||||||
|
}
|
||||||
|
|
||||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||||
if len(re) == 0 {
|
if len(re) == 0 {
|
||||||
return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
|
return nil, errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
|
||||||
|
@ -247,7 +260,7 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return nil, errors.Errorf("subscription %s is not valid", subscription)
|
return nil, errors.New("validation failed: invalid subscription id")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -287,6 +300,6 @@ func (p *Azure) assertConfig() error {
|
||||||
if p.config != nil {
|
if p.config != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p.config = newAzureConfig()
|
p.config = newAzureConfig(p.TenantID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
246
authority/provisioner/azure_test.go
Normal file
246
authority/provisioner/azure_test.go
Normal file
|
@ -0,0 +1,246 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAzure_Getters(t *testing.T) {
|
||||||
|
p, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
if got := p.GetID(); got != p.TenantID {
|
||||||
|
t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID)
|
||||||
|
}
|
||||||
|
if got := p.GetName(); got != p.Name {
|
||||||
|
t.Errorf("Azure.GetName() = %v, want %v", got, p.Name)
|
||||||
|
}
|
||||||
|
if got := p.GetType(); got != TypeAzure {
|
||||||
|
t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure)
|
||||||
|
}
|
||||||
|
kid, key, ok := p.GetEncryptedKey()
|
||||||
|
if kid != "" || key != "" || ok == true {
|
||||||
|
t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||||
|
kid, key, ok, "", "", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzure_GetTokenID(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
DisableCustomSANs bool
|
||||||
|
DisableTrustOnFirstUse bool
|
||||||
|
Claims *Claims
|
||||||
|
claimer *Claimer
|
||||||
|
config *azureConfig
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Azure{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||||
|
Claims: tt.fields.Claims,
|
||||||
|
claimer: tt.fields.claimer,
|
||||||
|
config: tt.fields.config,
|
||||||
|
}
|
||||||
|
got, err := p.GetTokenID(tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzure_Init(t *testing.T) {
|
||||||
|
az, srv, err := generateAzureWithServer()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
config := Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
}
|
||||||
|
badClaims := &Claims{
|
||||||
|
DefaultTLSDur: &Duration{0},
|
||||||
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
TenantID string
|
||||||
|
DisableCustomSANs bool
|
||||||
|
DisableTrustOnFirstUse bool
|
||||||
|
Claims *Claims
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{az.Type, az.Name, az.TenantID, false, false, nil}, args{config}, false},
|
||||||
|
{"ok", fields{az.Type, az.Name, az.TenantID, true, false, nil}, args{config}, false},
|
||||||
|
{"ok", fields{az.Type, az.Name, az.TenantID, false, true, nil}, args{config}, false},
|
||||||
|
{"ok", fields{az.Type, az.Name, az.TenantID, true, true, nil}, args{config}, false},
|
||||||
|
{"fail type", fields{"", az.Name, az.TenantID, false, false, nil}, args{config}, true},
|
||||||
|
{"fail name", fields{az.Type, "", az.TenantID, false, false, nil}, args{config}, true},
|
||||||
|
{"fail tenant id", fields{az.Type, az.Name, "", false, false, nil}, args{config}, true},
|
||||||
|
{"fail claims", fields{az.Type, az.Name, az.TenantID, false, false, badClaims}, args{config}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Azure{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
TenantID: tt.fields.TenantID,
|
||||||
|
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||||
|
Claims: tt.fields.Claims,
|
||||||
|
config: az.config,
|
||||||
|
}
|
||||||
|
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
DisableCustomSANs bool
|
||||||
|
DisableTrustOnFirstUse bool
|
||||||
|
Claims *Claims
|
||||||
|
claimer *Claimer
|
||||||
|
config *azureConfig
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want []SignOption
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Azure{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||||
|
Claims: tt.fields.Claims,
|
||||||
|
claimer: tt.fields.claimer,
|
||||||
|
config: tt.fields.config,
|
||||||
|
}
|
||||||
|
got, err := p.AuthorizeSign(tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Azure.AuthorizeSign() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzure_AuthorizeRenewal(t *testing.T) {
|
||||||
|
p1, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// disable renewal
|
||||||
|
disable := true
|
||||||
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
|
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
azure *Azure
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{nil}, false},
|
||||||
|
{"fail", p2, args{nil}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.azure.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzure_AuthorizeRevoke(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
DisableCustomSANs bool
|
||||||
|
DisableTrustOnFirstUse bool
|
||||||
|
Claims *Claims
|
||||||
|
claimer *Claimer
|
||||||
|
config *azureConfig
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Azure{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||||
|
Claims: tt.fields.Claims,
|
||||||
|
claimer: tt.fields.claimer,
|
||||||
|
config: tt.fields.config,
|
||||||
|
}
|
||||||
|
if err := p.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"time"
|
"time"
|
||||||
|
@ -328,6 +329,99 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
||||||
return aws, srv, nil
|
return aws, srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateAzure() (*Azure, error) {
|
||||||
|
name, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tenantID, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk, err := generateJSONWebKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Azure{
|
||||||
|
Type: "Azure",
|
||||||
|
Name: name,
|
||||||
|
TenantID: tenantID,
|
||||||
|
Claims: &globalProvisionerClaims,
|
||||||
|
claimer: claimer,
|
||||||
|
config: newAzureConfig(tenantID),
|
||||||
|
oidcConfig: openIDConfiguration{
|
||||||
|
Issuer: "https://sts.windows.net/" + tenantID + "/",
|
||||||
|
JWKSetURI: "https://login.microsoftonline.com/common/discovery/keys",
|
||||||
|
},
|
||||||
|
keyStore: &keyStore{
|
||||||
|
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||||
|
expiry: time.Now().Add(24 * time.Hour),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||||
|
az, err := generateAzure()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Add("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(b)
|
||||||
|
}
|
||||||
|
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
|
||||||
|
var ret jose.JSONWebKeySet
|
||||||
|
for _, k := range ks.Keys {
|
||||||
|
ret.Keys = append(ret.Keys, k.Public())
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
issuer := "https://sts.windows.net/" + az.TenantID + "/"
|
||||||
|
srv := httptest.NewUnstartedServer(nil)
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/error":
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
case "/" + az.TenantID + "/.well-known/openid-configuration":
|
||||||
|
writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"})
|
||||||
|
case "/random":
|
||||||
|
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, getPublic(keySet))
|
||||||
|
case "/private":
|
||||||
|
writeJSON(w, az.keyStore.keySet)
|
||||||
|
case "/jwks_uri":
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, getPublic(az.keyStore.keySet))
|
||||||
|
case "/metadata/identity/oauth2/token":
|
||||||
|
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
} else {
|
||||||
|
writeJSON(w, azureIdentityToken{
|
||||||
|
AccessToken: tok,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
srv.Start()
|
||||||
|
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
||||||
|
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
||||||
|
return az, srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||||
col := NewCollection(testAudiences)
|
col := NewCollection(testAudiences)
|
||||||
for i := 0; i < nJWK; i++ {
|
for i := 0; i < nJWK; i++ {
|
||||||
|
@ -468,6 +562,35 @@ func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region st
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
sig, err := jose.NewSigner(
|
||||||
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
|
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := azurePayload{
|
||||||
|
Claims: jose.Claims{
|
||||||
|
Subject: sub,
|
||||||
|
Issuer: iss,
|
||||||
|
IssuedAt: jose.NewNumericDate(iat),
|
||||||
|
NotBefore: jose.NewNumericDate(iat),
|
||||||
|
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||||
|
Audience: []string{aud},
|
||||||
|
},
|
||||||
|
AppID: "the-appid",
|
||||||
|
AppIDAcr: "the-appidacr",
|
||||||
|
IdentityProvider: "the-idp",
|
||||||
|
ObjectID: "the-oid",
|
||||||
|
TenantID: tenantID,
|
||||||
|
Version: "the-version",
|
||||||
|
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
|
||||||
|
}
|
||||||
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
}
|
||||||
|
|
||||||
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||||
tok, err := jose.ParseSigned(token)
|
tok, err := jose.ParseSigned(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue