distribution/vendor/github.com/Azure/azure-sdk-for-go/sdk/azidentity/azidentity.go

166 lines
5.9 KiB
Go
Raw Normal View History

//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azidentity
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/url"
"os"
"regexp"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)
const (
azureAuthorityHost = "AZURE_AUTHORITY_HOST"
azureClientCertificatePassword = "AZURE_CLIENT_CERTIFICATE_PASSWORD"
azureClientCertificatePath = "AZURE_CLIENT_CERTIFICATE_PATH"
azureClientID = "AZURE_CLIENT_ID"
azureClientSecret = "AZURE_CLIENT_SECRET"
azurePassword = "AZURE_PASSWORD"
azureRegionalAuthorityName = "AZURE_REGIONAL_AUTHORITY_NAME"
azureTenantID = "AZURE_TENANT_ID"
azureUsername = "AZURE_USERNAME"
organizationsTenantID = "organizations"
developerSignOnClientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
defaultSuffix = "/.default"
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)
func getConfidentialClient(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidential.Client, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(co.Cloud)
if err != nil {
return confidential.Client{}, err
}
o := []confidential.Option{
confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(co)),
}
o = append(o, additionalOpts...)
return confidential.New(clientID, cred, o...)
}
func getPublicClient(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) {
if !validTenantID(tenantID) {
return public.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(co.Cloud)
if err != nil {
return public.Client{}, err
}
return public.New(clientID,
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithHTTPClient(newPipelineAdapter(co)),
)
}
// setAuthorityHost initializes the authority host for credentials. Precedence is:
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
// 2. value of AZURE_AUTHORITY_HOST
// 3. default: Azure Public Cloud
func setAuthorityHost(cc cloud.Configuration) (string, error) {
host := cc.ActiveDirectoryAuthorityHost
if host == "" {
if len(cc.Services) > 0 {
return "", errors.New("missing ActiveDirectoryAuthorityHost for specified cloud")
}
host = cloud.AzurePublic.ActiveDirectoryAuthorityHost
if envAuthorityHost := os.Getenv(azureAuthorityHost); envAuthorityHost != "" {
host = envAuthorityHost
}
}
u, err := url.Parse(host)
if err != nil {
return "", err
}
if u.Scheme != "https" {
return "", errors.New("cannot use an authority host without https")
}
return host, nil
}
// validTenantID return true is it receives a valid tenantID, returns false otherwise
func validTenantID(tenantID string) bool {
match, err := regexp.MatchString("^[0-9a-zA-Z-.]+$", tenantID)
if err != nil {
return false
}
return match
}
func newPipelineAdapter(opts *azcore.ClientOptions) pipelineAdapter {
pl := runtime.NewPipeline(component, version, runtime.PipelineOptions{}, opts)
return pipelineAdapter{pl: pl}
}
type pipelineAdapter struct {
pl runtime.Pipeline
}
func (p pipelineAdapter) CloseIdleConnections() {
// do nothing
}
func (p pipelineAdapter) Do(r *http.Request) (*http.Response, error) {
req, err := runtime.NewRequest(r.Context(), r.Method, r.URL.String())
if err != nil {
return nil, err
}
if r.Body != nil && r.Body != http.NoBody {
// create a rewindable body from the existing body as required
var body io.ReadSeekCloser
if rsc, ok := r.Body.(io.ReadSeekCloser); ok {
body = rsc
} else {
b, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
body = streaming.NopCloser(bytes.NewReader(b))
}
err = req.SetBody(body, r.Header.Get("Content-Type"))
if err != nil {
return nil, err
}
}
resp, err := p.pl.Do(req)
if err != nil {
return nil, err
}
return resp, err
}
// enables fakes for test scenarios
type confidentialClient interface {
AcquireTokenSilent(ctx context.Context, scopes []string, options ...confidential.AcquireSilentOption) (confidential.AuthResult, error)
AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, options ...confidential.AcquireByAuthCodeOption) (confidential.AuthResult, error)
AcquireTokenByCredential(ctx context.Context, scopes []string, options ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error)
}
// enables fakes for test scenarios
type publicClient interface {
AcquireTokenSilent(ctx context.Context, scopes []string, options ...public.AcquireSilentOption) (public.AuthResult, error)
AcquireTokenByUsernamePassword(ctx context.Context, scopes []string, username string, password string, options ...public.AcquireByUsernamePasswordOption) (public.AuthResult, error)
AcquireTokenByDeviceCode(ctx context.Context, scopes []string, options ...public.AcquireByDeviceCodeOption) (public.DeviceCode, error)
AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, options ...public.AcquireByAuthCodeOption) (public.AuthResult, error)
AcquireTokenInteractive(ctx context.Context, scopes []string, options ...public.AcquireInteractiveOption) (public.AuthResult, error)
}